diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index 9072eb7c1a..b914c7f7f7 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -6,7 +6,11 @@ from memgpt.functions.functions import parse_source_code from memgpt.functions.schema_generator import generate_schema -from memgpt.models.pydantic_models import ToolModel +from memgpt.server.rest_api.admin.tools import ( + CreateToolRequest, + ListToolsResponse, + ToolModel, +) from memgpt.server.rest_api.admin.users import ( CreateAPIKeyResponse, CreateUserResponse, @@ -15,7 +19,6 @@ GetAllUsersResponse, GetAPIKeysResponse, ) -from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse class Admin: @@ -86,6 +89,7 @@ def _reset_server(self): self.delete_key(key) self.delete_user(user["user_id"]) + # tools def create_tool( self, func, @@ -94,12 +98,10 @@ def create_tool( tags: Optional[List[str]] = None, ): """Create a tool - Args: func (callable): The function to create a tool for. tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. update (bool, optional): Update the tool if it already exists. Defaults to True. - Returns: Tool object """ @@ -110,11 +112,11 @@ def create_tool( source_code = parse_source_code(func) json_schema = generate_schema(func, name) source_type = "python" - tool_name = json_schema["name"] + json_schema["name"] # create data - data = {"name": tool_name, "source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema} - CreateToolRequest(**data) # validate data:w + data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema} + CreateToolRequest(**data) # validate # make REST request response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index ec86471f28..047ab7daea 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -28,8 +28,6 @@ SourceModel, ToolModel, ) - -# import pydantic response objects from memgpt.server.rest_api from memgpt.server.rest_api.agents.command import CommandResponse from memgpt.server.rest_api.agents.config import GetAgentResponse from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse @@ -54,6 +52,9 @@ ListPresetsResponse, ) from memgpt.server.rest_api.sources.index import ListSourcesResponse + +# import pydantic response objects from memgpt.server.rest_api +from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse from memgpt.server.server import SyncServer @@ -235,8 +236,6 @@ def __init__( self.base_url = base_url self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"} - # agents - def list_agents(self): response = requests.get(f"{self.base_url}/api/agents", headers=self.headers) return ListAgentsResponse(**response.json()) @@ -610,6 +609,67 @@ def get_config(self) -> ConfigResponse: response = requests.get(f"{self.base_url}/api/config", headers=self.headers) return ConfigResponse(**response.json()) + # tools + + def create_tool( + self, + func, + name: Optional[str] = None, + update: Optional[bool] = True, # TODO: actually use this + tags: Optional[List[str]] = None, + ): + """Create a tool + + Args: + func (callable): The function to create a tool for. + tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. + update (bool, optional): Update the tool if it already exists. Defaults to True. + + Returns: + Tool object + """ + + # TODO: check if tool already exists + # TODO: how to load modules? + # parse source code/schema + source_code = parse_source_code(func) + json_schema = generate_schema(func, name) + source_type = "python" + json_schema["name"] + + # create data + data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema} + try: + CreateToolRequest(**data) # validate data + except Exception as e: + raise ValueError(f"Failed to create tool: {e}, invalid input {data}") + + # make REST request + response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to create tool: {response.text}") + return ToolModel(**response.json()) + + def list_tools(self) -> ListToolsResponse: + response = requests.get(f"{self.base_url}/api/tools", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to list tools: {response.text}") + return ListToolsResponse(**response.json()).tools + + def delete_tool(self, name: str): + response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to delete tool: {response.text}") + return response.json() + + def get_tool(self, name: str): + response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers) + if response.status_code == 404: + return None + elif response.status_code != 200: + raise ValueError(f"Failed to get tool: {response.text}") + return ToolModel(**response.json()) + class LocalClient(AbstractClient): def __init__( @@ -820,7 +880,7 @@ def create_tool( tool_name = json_schema["name"] # check if already exists: - existing_tool = self.server.ms.get_tool(tool_name) + existing_tool = self.server.ms.get_tool(tool_name, self.user_id) if existing_tool: if update: # update existing tool @@ -829,13 +889,15 @@ def create_tool( existing_tool.tags = tags existing_tool.json_schema = json_schema self.server.ms.update_tool(existing_tool) - return self.server.ms.get_tool(tool_name) + return self.server.ms.get_tool(tool_name, self.user_id) else: raise ValueError(f"Tool {name} already exists and update=False") - tool = ToolModel(name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema) + tool = ToolModel( + name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema, user_id=self.user_id + ) self.server.ms.add_tool(tool) - return self.server.ms.get_tool(tool_name) + return self.server.ms.get_tool(tool_name, self.user_id) def list_tools(self): """List available tools. @@ -844,7 +906,13 @@ def list_tools(self): tools (List[ToolModel]): A list of available tools. """ - return self.server.ms.list_tools() + return self.server.ms.list_tools(user_id=self.user_id) + + def get_tool(self, name: str): + return self.server.ms.get_tool(name, user_id=self.user_id) + + def delete_tool(self, name: str): + return self.server.ms.delete_tool(name, user_id=self.user_id) # data sources diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 62a67425d9..29491df7ad 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -604,9 +604,11 @@ def list_presets(self, user_id: uuid.UUID) -> List[Preset]: @enforce_types # def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools - def list_tools(self) -> List[ToolModel]: + def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]: with self.session_maker() as session: - results = session.query(ToolModel).all() + results = session.query(ToolModel).filter(ToolModel.user_id == None).all() + if user_id: + results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all() return results @enforce_types @@ -677,10 +679,13 @@ def get_source( return results[0].to_record() @enforce_types - def get_tool(self, tool_name: str) -> Optional[ToolModel]: + def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]: # TODO: add user_id when tools can eventually be added by users with self.session_maker() as session: - results = session.query(ToolModel).filter(ToolModel.name == tool_name).all() + results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all() + if user_id: + results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all() + if len(results) == 0: return None assert len(results) == 1, f"Expected 1 result, got {len(results)}" @@ -752,6 +757,8 @@ def add_preset(self, preset: PresetModel): @enforce_types def add_tool(self, tool: ToolModel): with self.session_maker() as session: + if self.get_tool(tool.name, tool.user_id): + raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}") session.add(tool) session.commit() @@ -811,9 +818,9 @@ def delete_preset(self, name: str, user_id: uuid.UUID): session.commit() @enforce_types - def delete_tool(self, name: str): + def delete_tool(self, name: str, user_id: uuid.UUID): with self.session_maker() as session: - session.query(ToolModel).filter(ToolModel.name == name).delete() + session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete() session.commit() # job related functions diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 74558c3809..2e3ec8b69a 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -56,8 +56,8 @@ class PresetModel(BaseModel): class ToolModel(SQLModel, table=True): # TODO move into database - name: str = Field(..., description="The name of the function.", primary_key=True) - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.") + name: str = Field(..., description="The name of the function.") + id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True) tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.") source_type: Optional[str] = Field(None, description="The type of the source code.") source_code: Optional[str] = Field(..., description="The source code of the function.") @@ -65,6 +65,9 @@ class ToolModel(SQLModel, table=True): json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.") + # optional: user_id (user-specific tools) + user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.") + # Needed for Column(JSON) class Config: arbitrary_types_allowed = True diff --git a/memgpt/server/rest_api/admin/tools.py b/memgpt/server/rest_api/admin/tools.py index f3f5dd9f68..a20a21f516 100644 --- a/memgpt/server/rest_api/admin/tools.py +++ b/memgpt/server/rest_api/admin/tools.py @@ -39,7 +39,7 @@ async def delete_tool( # Clear the interface interface.clear() # tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific - server.ms.delete_tool(name=tool_name) + server.ms.delete_tool(name=tool_name, user_id=None) @router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel) async def get_tool(tool_name: str): @@ -49,29 +49,26 @@ async def get_tool(tool_name: str): # Clear the interface interface.clear() # tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific - tool = server.ms.get_tool(tool_name=tool_name) + tool = server.ms.get_tool(tool_name=tool_name, user_id=None) if tool is None: # return 404 error raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") return tool @router.get("/tools", tags=["tools"], response_model=ListToolsResponse) - async def list_all_tools( - # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific - ): + async def list_all_tools(): """ Get a list of all tools available to agents created by a user """ # Clear the interface interface.clear() # tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific - tools = server.ms.list_tools() + tools = server.ms.list_tools(user_id=None) return ListToolsResponse(tools=tools) @router.post("/tools", tags=["tools"], response_model=ToolModel) async def create_tool( request: CreateToolRequest = Body(...), - # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific ): """ Create a new tool diff --git a/memgpt/server/rest_api/tools/index.py b/memgpt/server/rest_api/tools/index.py index 7070502773..b564dc7263 100644 --- a/memgpt/server/rest_api/tools/index.py +++ b/memgpt/server/rest_api/tools/index.py @@ -2,7 +2,7 @@ from functools import partial from typing import List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel, Field from memgpt.models.pydantic_models import ToolModel @@ -18,7 +18,7 @@ class ListToolsResponse(BaseModel): class CreateToolRequest(BaseModel): - name: str = Field(..., description="The name of the function.") + json_schema: dict = Field(..., description="JSON schema of the tool.") source_code: str = Field(..., description="The source code of the function.") source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.") tags: Optional[List[str]] = Field(None, description="Metadata tags.") @@ -31,18 +31,30 @@ class CreateToolResponse(BaseModel): def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) + @router.delete("/tools/{tool_name}", tags=["tools"]) + async def delete_tool( + tool_name: str, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """ + Delete a tool by name + """ + # Clear the interface + interface.clear() + # tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific + server.ms.delete_tool(name=tool_name, user_id=user_id) + @router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel) async def get_tool( tool_name: str, - user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific + user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ Get a tool by name """ # Clear the interface interface.clear() - # tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific - tool = server.ms.get_tool(tool_name=tool_name) + tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id) if tool is None: # return 404 error raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") @@ -50,15 +62,33 @@ async def get_tool( @router.get("/tools", tags=["tools"], response_model=ListToolsResponse) async def list_all_tools( - user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific + user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ Get a list of all tools available to agents created by a user """ # Clear the interface interface.clear() - # tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific - tools = server.ms.list_tools() + tools = server.ms.list_tools(user_id=user_id) return ListToolsResponse(tools=tools) + @router.post("/tools", tags=["tools"], response_model=ToolModel) + async def create_tool( + request: CreateToolRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """ + Create a new tool + """ + try: + return server.create_tool( + json_schema=request.json_schema, + source_code=request.source_code, + source_type=request.source_type, + tags=request.tags, + user_id=user_id, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}") + return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 09a95076bc..49a6c1b291 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -336,7 +336,7 @@ def _load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, interface: Union[ # Instantiate an agent object using the state retrieved logger.info(f"Creating an agent object") - tool_objs = [self.ms.get_tool(name) for name in agent_state.tools] # get tool objects + tool_objs = [self.ms.get_tool(name, user_id) for name in agent_state.tools] # get tool objects memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) # Add the agent to the in-memory store and return its reference @@ -763,7 +763,7 @@ def create_agent( # get tools tool_objs = [] for tool_name in tools: - tool_obj = self.ms.get_tool(tool_name) + tool_obj = self.ms.get_tool(tool_name, user_id=user_id) assert tool_obj is not None, f"Tool {tool_name} does not exist" tool_objs.append(tool_obj) @@ -1487,7 +1487,13 @@ def list_all_sources(self, user_id: uuid.UUID) -> List[SourceModel]: return sources_with_metadata def create_tool( - self, json_schema: dict, source_code: str, source_type: str, tags: Optional[List[str]] = None, exists_ok: Optional[bool] = True + self, + json_schema: dict, + source_code: str, + source_type: str, + tags: Optional[List[str]] = None, + exists_ok: Optional[bool] = True, + user_id: Optional[uuid.UUID] = None, ) -> ToolModel: # TODO: add other fields """Create a new tool @@ -1511,10 +1517,12 @@ def create_tool( raise ValueError(f"Tool with name {name} already exists.") else: # create new tool - tool = ToolModel(name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type) + tool = ToolModel( + name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type, user_id=user_id + ) self.ms.add_tool(tool) - return self.ms.get_tool(name) + return self.ms.get_tool(name, user_id=user_id) def delete_tool(self, name: str): """Delete a tool""" diff --git a/tests/test_tools.py b/tests/test_tools.py index e4d02f063b..9d03ab88dc 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -68,11 +68,12 @@ def run_server(): # Fixture to create clients with different configurations @pytest.fixture( - # params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented - params=[{"server": True}], # whether to use REST API server # TODO: add when implemented + params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented + # params=[{"server": True}], # whether to use REST API server # TODO: add when implemented scope="module", ) -def client(request): +def admin_client(request): + if request.param["server"]: # get URL from enviornment server_url = os.getenv("MEMGPT_SERVER_URL") @@ -92,11 +93,19 @@ def client(request): admin._reset_server() else: - print("Testing local client") - # use local client (no server) - token = None - server_url = None - client = create_client(base_url=server_url, token=token) # This yields control back to the test function + yield None + + +@pytest.fixture(scope="module") +def client(admin_client): + if admin_client: + # create user via admin client + response = admin_client.create_user() + print("Created user", response.user_id, response.api_key) + client = create_client(base_url=admin_client.base_url, token=response.api_key) + yield client + else: + client = create_client() yield client @@ -124,6 +133,39 @@ def print_tool(message: str): assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}" print(f"Updated tools {[t.name for t in tools]}") + # check tool id + tool = client.get_tool(tool.name) + + +def test_create_agent_tool_admin(admin_client): + if admin_client is None: + return + + def print_tool(message: str): + """ + Args: + message (str): The message to print. + + Returns: + str: The message that was printed. + + """ + print(message) + return message + + tools = admin_client.list_tools() + print(f"Original tools {[t.name for t in tools]}") + + tool = admin_client.create_tool(print_tool, tags=["extras"]) + + tools = admin_client.list_tools() + assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}" + print(f"Updated tools {[t.name for t in tools]}") + + # check tool id + tool = admin_client.get_tool(tool.name) + assert tool.user_id is None, f"Expected {tool.user_id} to be None" + def test_create_agent_tool(client): """Test creation of a agent tool""" @@ -144,22 +186,15 @@ def core_memory_clear(self: Agent): return None # TODO: test attaching and using function on agent - tool = client.create_tool(core_memory_clear, tags=["extras"]) - - if isinstance(client, Admin): - # conver to user client type - response = client.create_user() - print("Created user", response.user_id, response.api_key) - user_client = create_client(base_url=client.base_url, token=response.api_key) - else: - user_client = client + tool = client.create_tool(core_memory_clear, tags=["extras"], update=True) + print(f"Created tool", tool.name) - agent = user_client.create_agent( - name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you" - ) + # create agent with tool + agent = client.create_agent(name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you") + assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}" # initial memory - initial_memory = user_client.get_agent_memory(agent.id) + initial_memory = client.get_agent_memory(agent.id) print("initial memory", initial_memory) human = initial_memory.core_memory.human persona = initial_memory.core_memory.persona @@ -168,11 +203,11 @@ def core_memory_clear(self: Agent): assert len(persona) > 0, "Expected persona memory to be non-empty" # test agent tool - response = user_client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool") + response = client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool") print(response) # updated memory - updated_memory = user_client.get_agent_memory(agent.id) + updated_memory = client.get_agent_memory(agent.id) human = updated_memory.core_memory.human persona = updated_memory.core_memory.persona print("Updated memory:", human, persona)