diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d81681e471..eade4661a7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,6 +30,7 @@ jobs: - "test_agent_tool_graph.py" - "test_utils.py" - "test_tool_schema_parsing.py" + - "test_v1_routes.py" services: qdrant: image: qdrant/qdrant @@ -132,4 +133,4 @@ jobs: LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | - poetry run pytest -s -vv -k "not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests + poetry run pytest -s -vv -k "not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests diff --git a/letta/functions/functions.py b/letta/functions/functions.py index 2e55bcdb1d..8ccb831bad 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -1,11 +1,8 @@ -import importlib import inspect -import os from textwrap import dedent # remove indentation from types import ModuleType from typing import Dict, List, Optional -from letta.constants import CLI_WARNING_PREFIX from letta.errors import LettaToolCreateError from letta.functions.schema_generator import generate_schema @@ -90,46 +87,3 @@ def load_function_set(module: ModuleType) -> dict: if len(function_dict) == 0: raise ValueError(f"No functions found in module {module}") return function_dict - - -def validate_function(module_name, module_full_path): - try: - file = os.path.basename(module_full_path) - spec = importlib.util.spec_from_file_location(module_name, module_full_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - except ModuleNotFoundError as e: - # Handle missing module imports - missing_package = str(e).split("'")[1] # Extract the name of the missing package - print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!") - return ( - False, - f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta.", - ) - except SyntaxError as e: - # Handle syntax errors in the module - return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}" - except Exception as e: - # Handle other general exceptions - return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}" - - return True, None - - -def load_function_file(filepath: str) -> dict: - file = os.path.basename(filepath) - module_name = file[:-3] # Remove '.py' from filename - try: - spec = importlib.util.spec_from_file_location(module_name, filepath) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - except ModuleNotFoundError as e: - # Handle missing module imports - missing_package = str(e).split("'")[1] # Extract the name of the missing package - print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!") - print( - f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta." - ) - # load all functions in the module - function_dict = load_function_set(module) - return function_dict diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 57204a85b6..41daae83aa 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -1,5 +1,6 @@ from typing import List, Optional +from composio.client.collections import ActionModel, AppModel from fastapi import APIRouter, Body, Depends, Header, HTTPException from letta.errors import LettaToolCreateError @@ -156,3 +157,39 @@ def add_base_tools( """ actor = server.get_user_or_default(user_id=user_id) return server.tool_manager.add_base_tools(actor=actor) + + +# Specific routes for Composio + + +@router.get("/composio/apps", response_model=List[AppModel], operation_id="list_composio_apps") +def list_composio_apps(server: SyncServer = Depends(get_letta_server)): + """ + Get a list of all Composio apps + """ + return server.get_composio_apps() + + +@router.get("/composio/apps/{composio_app_name}/actions", response_model=List[ActionModel], operation_id="list_composio_actions_by_app") +def list_composio_actions_by_app( + composio_app_name: str, + server: SyncServer = Depends(get_letta_server), +): + """ + Get a list of all Composio actions for a specific app + """ + return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name) + + +@router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool") +def add_composio_tool( + composio_action_name: str, + server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Add a new Composio tool by action name (Composio refers to each tool as an `Action`) + """ + actor = server.get_user_or_default(user_id=user_id) + tool_create = ToolCreate.from_composio(action=composio_action_name) + return server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor) diff --git a/letta/server/server.py b/letta/server/server.py index 832753115c..c8607e7f57 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -7,6 +7,8 @@ from datetime import datetime from typing import Callable, Dict, List, Optional, Tuple, Union +from composio.client import Composio +from composio.client.collections import ActionModel, AppModel from fastapi import HTTPException import letta.constants as constants @@ -227,6 +229,11 @@ def __init__( # Locks self.send_message_lock = Lock() + # Composio + self.composio_client = None + if tool_settings.composio_api_key: + self.composio_client = Composio(api_key=tool_settings.composio_api_key) + # Initialize the metadata store config = LettaConfig.load() if settings.letta_pg_uri_no_default: @@ -1750,3 +1757,18 @@ def get_agent_block_by_label(self, user_id: str, agent_id: str, label: str) -> B if block.label == label: return block return None + + # Composio wrappers + def get_composio_apps(self) -> List["AppModel"]: + """Get a list of all Composio apps with actions""" + apps = self.composio_client.apps.get() + apps_with_actions = [] + for app in apps: + if app.meta["actionsCount"] > 0: + apps_with_actions.append(app) + + return apps_with_actions + + def get_composio_actions_from_app_name(self, composio_app_name: str) -> List["ActionModel"]: + actions = self.composio_client.actions.get(apps=[composio_app_name]) + return actions diff --git a/poetry.lock b/poetry.lock index d2f0d6f4ad..30eddfadbe 100644 --- a/poetry.lock +++ b/poetry.lock @@ -889,7 +889,7 @@ test = ["pytest"] name = "composio-core" version = "0.5.44" description = "Core package to act as a bridge between composio platform and other services." -optional = true +optional = false python-versions = "<4,>=3.9" files = [ {file = "composio_core-0.5.44-py3-none-any.whl", hash = "sha256:bb125794035a3c3c98dab1e72b45024068019c5eb3f29b9cc4eafc845320774b"}, @@ -925,7 +925,7 @@ tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "tra name = "composio-langchain" version = "0.5.44" description = "Use Composio to get an array of tools with your LangChain agent." -optional = true +optional = false python-versions = "<4,>=3.9" files = [ {file = "composio_langchain-0.5.44-py3-none-any.whl", hash = "sha256:4cb05d5b92faea32bc02c04e49b5dfae5858abe2f6469a81c673d7f754402375"}, @@ -958,7 +958,7 @@ yaml = ["PyYAML"] name = "cryptography" version = "43.0.3" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, @@ -2304,7 +2304,7 @@ type = ["pytest-mypy"] name = "inflection" version = "0.5.1" description = "A port of Ruby on Rails inflector to Python" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2"}, @@ -2565,7 +2565,7 @@ files = [ name = "jsonpatch" version = "1.33" description = "Apply JSON-Patches (RFC 6902)" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, @@ -2579,7 +2579,7 @@ jsonpointer = ">=1.9" name = "jsonpointer" version = "3.0.0" description = "Identify specific nodes in a JSON document (RFC 6901)" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, @@ -2590,7 +2590,7 @@ files = [ name = "jsonref" version = "1.1.0" description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "jsonref-1.1.0-py3-none-any.whl", hash = "sha256:590dc7773df6c21cbf948b5dac07a72a251db28b0238ceecce0a2abfa8ec30a9"}, @@ -2601,7 +2601,7 @@ files = [ name = "jsonschema" version = "4.23.0" description = "An implementation of JSON Schema validation for Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, @@ -2622,7 +2622,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2024.10.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"}, @@ -2705,7 +2705,7 @@ adal = ["adal (>=1.0.2)"] name = "langchain" version = "0.3.7" description = "Building applications with LLMs through composability" -optional = true +optional = false python-versions = "<4.0,>=3.9" files = [ {file = "langchain-0.3.7-py3-none-any.whl", hash = "sha256:cf4af1d5751dacdc278df3de1ff3cbbd8ca7eb55d39deadccdd7fb3d3ee02ac0"}, @@ -2760,7 +2760,7 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" name = "langchain-core" version = "0.3.19" description = "Building applications with LLMs through composability" -optional = true +optional = false python-versions = "<4.0,>=3.9" files = [ {file = "langchain_core-0.3.19-py3-none-any.whl", hash = "sha256:562b7cc3c15dfaa9270cb1496990c1f3b3e0b660c4d6a3236d7f693346f2a96c"}, @@ -2783,7 +2783,7 @@ typing-extensions = ">=4.7" name = "langchain-openai" version = "0.2.9" description = "An integration package connecting OpenAI and LangChain" -optional = true +optional = false python-versions = "<4.0,>=3.9" files = [ {file = "langchain_openai-0.2.9-py3-none-any.whl", hash = "sha256:2723015e56879f9e5edfcb175fdbec6c296c1b3bf65caad28579ce9c4d1bd652"}, @@ -2799,7 +2799,7 @@ tiktoken = ">=0.7,<1" name = "langchain-text-splitters" version = "0.3.2" description = "LangChain text splitting utilities" -optional = true +optional = false python-versions = "<4.0,>=3.9" files = [ {file = "langchain_text_splitters-0.3.2-py3-none-any.whl", hash = "sha256:0db28c53f41d1bc024cdb3b1646741f6d46d5371e90f31e7e7c9fbe75d01c726"}, @@ -2813,7 +2813,7 @@ langchain-core = ">=0.3.15,<0.4.0" name = "langchainhub" version = "0.1.21" description = "The LangChain Hub API client" -optional = true +optional = false python-versions = "<4.0,>=3.8.1" files = [ {file = "langchainhub-0.1.21-py3-none-any.whl", hash = "sha256:1cc002dc31e0d132a776afd044361e2b698743df5202618cf2bad399246b895f"}, @@ -2829,7 +2829,7 @@ types-requests = ">=2.31.0.2,<3.0.0.0" name = "langsmith" version = "0.1.144" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." -optional = true +optional = false python-versions = "<4.0,>=3.8.1" files = [ {file = "langsmith-0.1.144-py3-none-any.whl", hash = "sha256:08ffb975bff2e82fc6f5428837c64c074ea25102d08a25e256361a80812c6100"}, @@ -4243,7 +4243,7 @@ xml = ["lxml (>=4.9.2)"] name = "paramiko" version = "3.5.0" description = "SSH2 protocol library" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "paramiko-3.5.0-py3-none-any.whl", hash = "sha256:1fedf06b085359051cd7d0d270cebe19e755a8a921cc2ddbfa647fb0cd7d68f9"}, @@ -5214,7 +5214,7 @@ model = ["milvus-model (>=0.1.0)"] name = "pynacl" version = "1.5.0" description = "Python binding to the Networking and Cryptography (NaCl) library" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"}, @@ -5262,7 +5262,7 @@ image = ["Pillow (>=8.0.0)"] name = "pyperclip" version = "1.9.0" description = "A cross-platform clipboard module for Python. (Only handles plain text for now.)" -optional = true +optional = false python-versions = "*" files = [ {file = "pyperclip-1.9.0.tar.gz", hash = "sha256:b7de0142ddc81bfc5c7507eea19da920b92252b548b96186caf94a5e2527d310"}, @@ -5327,7 +5327,7 @@ nodejs = ["nodejs-wheel-binaries"] name = "pysher" version = "1.0.8" description = "Pusher websocket client for python, based on Erik Kulyk's PythonPusherClient" -optional = true +optional = false python-versions = "*" files = [ {file = "Pysher-1.0.8.tar.gz", hash = "sha256:7849c56032b208e49df67d7bd8d49029a69042ab0bb45b2ed59fa08f11ac5988"}, @@ -5734,7 +5734,7 @@ prompt_toolkit = ">=2.0,<=3.0.36" name = "referencing" version = "0.35.1" description = "JSON Referencing + Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, @@ -5891,7 +5891,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "requests-toolbelt" version = "1.0.0" description = "A utility belt for advanced users of python-requests" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, @@ -5924,7 +5924,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "rpds-py" version = "0.21.0" description = "Python bindings to Rust's persistent data structures (rpds)" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "rpds_py-0.21.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a017f813f24b9df929674d0332a374d40d7f0162b326562daae8066b502d0590"}, @@ -6051,7 +6051,7 @@ asn1crypto = ">=1.5.1" name = "semver" version = "3.0.2" description = "Python helper for Semantic Versioning (https://semver.org)" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"}, @@ -6062,7 +6062,7 @@ files = [ name = "sentry-sdk" version = "2.19.0" description = "Python client for Sentry (https://sentry.io)" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "sentry_sdk-2.19.0-py2.py3-none-any.whl", hash = "sha256:7b0b3b709dee051337244a09a30dbf6e95afe0d34a1f8b430d45e0982a7c125b"}, @@ -6699,7 +6699,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. name = "types-requests" version = "2.32.0.20241016" description = "Typing stubs for requests" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, @@ -7595,4 +7595,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "28cd26c6573ca0a07173262bc0e819e19b661157fa757efca0590262f9b9f35c" +content-hash = "9e0f7eb7ed1007cfeb0227d0f1bf20c5601c23e1d363ee23195ce6f8e134f14e" diff --git a/pyproject.toml b/pyproject.toml index 9a07491227..7b20a73840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,8 +66,8 @@ llama-index = "^0.11.9" llama-index-embeddings-openai = "^0.2.5" llama-index-embeddings-ollama = "^0.3.1" wikipedia = {version = "^1.4.0", optional = true} -composio-langchain = {version = "^0.5.28", optional = true} -composio-core = {version = "^0.5.34", optional = true} +composio-langchain = "^0.5.28" +composio-core = "^0.5.34" alembic = "^1.13.3" pyhumps = "^3.8.0" psycopg2 = "^2.9.10" diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index c5f1c15665..4269fdd8c6 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -1,6 +1,9 @@ from typing import Union from letta import LocalClient, RESTClient +from letta.functions.functions import parse_source_code +from letta.functions.schema_generator import generate_schema +from letta.schemas.tool import Tool def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str): @@ -9,3 +12,15 @@ def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str): if agent_state.name == agent_uuid: client.delete_agent(agent_id=agent_state.id) print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") + + +# Utility functions +def create_tool_from_func(func: callable): + return Tool( + name=func.__name__, + description="", + source_type="python", + tags=[], + source_code=parse_source_code(func), + json_schema=generate_schema(func, None), + ) diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 1d5f556e62..1c5dd05f1a 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -10,8 +10,6 @@ from letta import create_client from letta.functions.function_sets.base import core_memory_replace -from letta.functions.functions import parse_source_code -from letta.functions.schema_generator import generate_schema from letta.orm import SandboxConfig, SandboxEnvironmentVariable from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig @@ -34,6 +32,7 @@ from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import tool_settings +from tests.helpers.utils import create_tool_from_func # Constants namespace = uuid.NAMESPACE_DNS @@ -214,18 +213,6 @@ def agent_state(): yield agent_state -# Utility functions -def create_tool_from_func(func: callable): - return Tool( - name=func.__name__, - description="", - source_type="python", - tags=[], - source_code=parse_source_code(func), - json_schema=generate_schema(func, None), - ) - - # Local sandbox tests @pytest.mark.local_sandbox def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_user): diff --git a/tests/test_server.py b/tests/test_server.py index 94a799869e..ab36f5553a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -26,7 +26,6 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message -from letta.schemas.memory import ChatMemory from letta.schemas.source import Source from letta.server.server import SyncServer @@ -540,3 +539,15 @@ def _test_get_messages_letta_format( def test_get_messages_letta_format(server, user_id, agent_id): for reverse in [False, True]: _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse) + + +def test_composio_client_simple(server): + apps = server.get_composio_apps() + # Assert there's some amount of apps returned + assert len(apps) > 0 + + app = apps[0] + actions = server.get_composio_actions_from_app_name(composio_app_name=app.name) + + # Assert there's some amount of actions + assert len(actions) > 0 diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py new file mode 100644 index 0000000000..883395e1bd --- /dev/null +++ b/tests/test_v1_routes.py @@ -0,0 +1,317 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from composio.client.collections import ( + ActionModel, + ActionParametersModel, + ActionResponseModel, + AppModel, +) +from fastapi.testclient import TestClient + +from letta.schemas.tool import ToolCreate, ToolUpdate +from letta.server.rest_api.app import app +from letta.server.rest_api.utils import get_letta_server +from tests.helpers.utils import create_tool_from_func + + +@pytest.fixture +def client(): + return TestClient(app) + + +@pytest.fixture +def mock_sync_server(): + mock_server = Mock() + app.dependency_overrides[get_letta_server] = lambda: mock_server + return mock_server + + +@pytest.fixture +def add_integers_tool(): + def add(x: int, y: int) -> int: + """ + Simple function that adds two integers. + + Parameters: + x (int): The first integer to add. + y (int): The second integer to add. + + Returns: + int: The result of adding x and y. + """ + return x + y + + tool = create_tool_from_func(add) + yield tool + + +@pytest.fixture +def create_integers_tool(add_integers_tool): + tool_create = ToolCreate( + name=add_integers_tool.name, + description=add_integers_tool.description, + tags=add_integers_tool.tags, + module=add_integers_tool.module, + source_code=add_integers_tool.source_code, + source_type=add_integers_tool.source_type, + json_schema=add_integers_tool.json_schema, + ) + yield tool_create + + +@pytest.fixture +def update_integers_tool(add_integers_tool): + tool_update = ToolUpdate( + name=add_integers_tool.name, + description=add_integers_tool.description, + tags=add_integers_tool.tags, + module=add_integers_tool.module, + source_code=add_integers_tool.source_code, + source_type=add_integers_tool.source_type, + json_schema=add_integers_tool.json_schema, + ) + yield tool_update + + +@pytest.fixture +def composio_apps(): + affinity_app = AppModel( + name="affinity", + key="affinity", + appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8", + description="Affinity helps private capital investors to find, manage, and close more deals", + categories=["CRM"], + meta={ + "is_custom_app": False, + "triggersCount": 0, + "actionsCount": 20, + "documentation_doc_text": None, + "configuration_docs_text": None, + }, + logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg", + docs=None, + group=None, + status=None, + enabled=False, + no_auth=False, + auth_schemes=None, + testConnectors=None, + documentation_doc_text=None, + configuration_docs_text=None, + ) + yield [affinity_app] + + +@pytest.fixture +def composio_actions(): + yield [ + ActionModel( + name="AFFINITY_GET_ALL_COMPANIES", + display_name="Get all companies", + parameters=ActionParametersModel( + properties={ + "cursor": {"default": None, "description": "Cursor for the next or previous page", "title": "Cursor", "type": "string"}, + "limit": {"default": 100, "description": "Number of items to include in the page", "title": "Limit", "type": "integer"}, + "ids": {"default": None, "description": "Company IDs", "items": {"type": "integer"}, "title": "Ids", "type": "array"}, + "fieldIds": { + "default": None, + "description": "Field IDs for which to return field data", + "items": {"type": "string"}, + "title": "Fieldids", + "type": "array", + }, + "fieldTypes": { + "default": None, + "description": "Field Types for which to return field data", + "items": {"enum": ["enriched", "global", "relationship-intelligence"], "title": "FieldtypesEnm", "type": "string"}, + "title": "Fieldtypes", + "type": "array", + }, + }, + title="GetAllCompaniesRequest", + type="object", + required=None, + ), + response=ActionResponseModel( + properties={ + "data": {"title": "Data", "type": "object"}, + "successful": { + "description": "Whether or not the action execution was successful or not", + "title": "Successful", + "type": "boolean", + }, + "error": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "Error if any occurred during the execution of the action", + "title": "Error", + }, + }, + title="GetAllCompaniesResponse", + type="object", + required=["data", "successful"], + ), + appName="affinity", + appId="affinity", + tags=["companies", "important"], + enabled=False, + logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg", + description="Affinity Api Allows Paginated Access To Company Info And Custom Fields. Use `Field Ids` Or `Field Types` To Specify Data In A Request. Retrieve Field I Ds/Types Via Get `/V2/Companies/Fields`. Export Permission Needed.", + ) + ] + + +# ====================================================================================================================== +# Tools Routes Tests +# ====================================================================================================================== +def test_delete_tool(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.delete_tool_by_id = MagicMock() + + response = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"}) + + assert response.status_code == 200 + mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with( + tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value + ) + + +def test_get_tool(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.get_tool_by_id.return_value = add_integers_tool + + response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert response.json()["id"] == add_integers_tool.id + assert response.json()["source_code"] == add_integers_tool.source_code + mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with( + tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value + ) + + +def test_get_tool_404(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.get_tool_by_id.return_value = None + + response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"}) + + assert response.status_code == 404 + assert response.json()["detail"] == f"Tool with id {add_integers_tool.id} not found." + + +def test_get_tool_id(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.get_tool_by_name.return_value = add_integers_tool + + response = client.get(f"/v1/tools/name/{add_integers_tool.name}", headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert response.json() == add_integers_tool.id + mock_sync_server.tool_manager.get_tool_by_name.assert_called_once_with( + tool_name=add_integers_tool.name, actor=mock_sync_server.get_user_or_default.return_value + ) + + +def test_get_tool_id_404(client, mock_sync_server): + mock_sync_server.tool_manager.get_tool_by_name.return_value = None + + response = client.get("/v1/tools/name/UnknownTool", headers={"user_id": "test_user"}) + + assert response.status_code == 404 + assert "Tool with name UnknownTool" in response.json()["detail"] + + +def test_list_tools(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.list_tools.return_value = [add_integers_tool] + + response = client.get("/v1/tools", headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["id"] == add_integers_tool.id + mock_sync_server.tool_manager.list_tools.assert_called_once() + + +def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool): + mock_sync_server.tool_manager.create_tool.return_value = add_integers_tool + + response = client.post("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert response.json()["id"] == add_integers_tool.id + mock_sync_server.tool_manager.create_tool.assert_called_once() + + +def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool): + mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool + + response = client.put("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert response.json()["id"] == add_integers_tool.id + mock_sync_server.tool_manager.create_or_update_tool.assert_called_once() + + +def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool): + mock_sync_server.tool_manager.update_tool_by_id.return_value = add_integers_tool + + response = client.patch(f"/v1/tools/{add_integers_tool.id}", json=update_integers_tool.model_dump(), headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert response.json()["id"] == add_integers_tool.id + mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with( + tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.get_user_or_default.return_value + ) + + +def test_add_base_tools(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool] + + response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"}) + + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["id"] == add_integers_tool.id + mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(actor=mock_sync_server.get_user_or_default.return_value) + + +def test_list_composio_apps(client, mock_sync_server, composio_apps): + mock_sync_server.get_composio_apps.return_value = composio_apps + + response = client.get("/v1/tools/composio/apps") + + assert response.status_code == 200 + assert len(response.json()) == 1 + mock_sync_server.get_composio_apps.assert_called_once() + + +def test_list_composio_actions_by_app(client, mock_sync_server, composio_actions): + mock_sync_server.get_composio_actions_from_app_name.return_value = composio_actions + + response = client.get("/v1/tools/composio/apps/App1/actions") + + assert response.status_code == 200 + assert len(response.json()) == 1 + mock_sync_server.get_composio_actions_from_app_name.assert_called_once_with(composio_app_name="App1") + + +def test_add_composio_tool(client, mock_sync_server, add_integers_tool): + # Mock ToolCreate.from_composio to return the expected ToolCreate object + with patch("letta.schemas.tool.ToolCreate.from_composio") as mock_from_composio: + mock_from_composio.return_value = ToolCreate( + name=add_integers_tool.name, + source_code=add_integers_tool.source_code, + json_schema=add_integers_tool.json_schema, + ) + + # Mock server behavior + mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool + + # Perform the request + response = client.post(f"/v1/tools/composio/{add_integers_tool.name}", headers={"user_id": "test_user"}) + + # Assertions + assert response.status_code == 200 + assert response.json()["id"] == add_integers_tool.id + mock_sync_server.tool_manager.create_or_update_tool.assert_called_once() + + # Verify the mocked from_composio method was called + mock_from_composio.assert_called_once_with(action=add_integers_tool.name)