Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move tool functions to user #1487

Merged
merged 6 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions memgpt/client/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.admin.tools import (
CreateToolRequest,
ListToolsResponse,
ToolModel,
)
from memgpt.server.rest_api.admin.users import (
CreateAPIKeyResponse,
CreateUserResponse,
Expand All @@ -15,7 +19,6 @@
GetAllUsersResponse,
GetAPIKeysResponse,
)
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse


class Admin:
Expand Down Expand Up @@ -86,6 +89,7 @@ def _reset_server(self):
self.delete_key(key)
self.delete_user(user["user_id"])

# tools
def create_tool(
self,
func,
Expand All @@ -94,12 +98,10 @@ def create_tool(
tags: Optional[List[str]] = None,
):
"""Create a tool

Args:
func (callable): The function to create a tool for.
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
update (bool, optional): Update the tool if it already exists. Defaults to True.

Returns:
Tool object
"""
Expand All @@ -110,11 +112,11 @@ def create_tool(
source_code = parse_source_code(func)
json_schema = generate_schema(func, name)
source_type = "python"
tool_name = json_schema["name"]
json_schema["name"]

# create data
data = {"name": tool_name, "source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
CreateToolRequest(**data) # validate data:w
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
CreateToolRequest(**data) # validate

# make REST request
response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers)
Expand Down
86 changes: 77 additions & 9 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
SourceModel,
ToolModel,
)

# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.agents.command import CommandResponse
from memgpt.server.rest_api.agents.config import GetAgentResponse
from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse
Expand All @@ -54,6 +52,9 @@
ListPresetsResponse,
)
from memgpt.server.rest_api.sources.index import ListSourcesResponse

# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer


Expand Down Expand Up @@ -235,8 +236,6 @@ def __init__(
self.base_url = base_url
self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}

# agents

def list_agents(self):
response = requests.get(f"{self.base_url}/api/agents", headers=self.headers)
return ListAgentsResponse(**response.json())
Expand Down Expand Up @@ -610,6 +609,67 @@ def get_config(self) -> ConfigResponse:
response = requests.get(f"{self.base_url}/api/config", headers=self.headers)
return ConfigResponse(**response.json())

# tools

def create_tool(
self,
func,
name: Optional[str] = None,
update: Optional[bool] = True, # TODO: actually use this
tags: Optional[List[str]] = None,
):
"""Create a tool

Args:
func (callable): The function to create a tool for.
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
update (bool, optional): Update the tool if it already exists. Defaults to True.

Returns:
Tool object
"""

# TODO: check if tool already exists
# TODO: how to load modules?
# parse source code/schema
source_code = parse_source_code(func)
json_schema = generate_schema(func, name)
source_type = "python"
json_schema["name"]

# create data
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
try:
CreateToolRequest(**data) # validate data
except Exception as e:
raise ValueError(f"Failed to create tool: {e}, invalid input {data}")

# make REST request
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return ToolModel(**response.json())

def list_tools(self) -> ListToolsResponse:
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list tools: {response.text}")
return ListToolsResponse(**response.json()).tools

def delete_tool(self, name: str):
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to delete tool: {response.text}")
return response.json()

def get_tool(self, name: str):
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return ToolModel(**response.json())


class LocalClient(AbstractClient):
def __init__(
Expand Down Expand Up @@ -820,7 +880,7 @@ def create_tool(
tool_name = json_schema["name"]

# check if already exists:
existing_tool = self.server.ms.get_tool(tool_name)
existing_tool = self.server.ms.get_tool(tool_name, self.user_id)
if existing_tool:
if update:
# update existing tool
Expand All @@ -829,13 +889,15 @@ def create_tool(
existing_tool.tags = tags
existing_tool.json_schema = json_schema
self.server.ms.update_tool(existing_tool)
return self.server.ms.get_tool(tool_name)
return self.server.ms.get_tool(tool_name, self.user_id)
else:
raise ValueError(f"Tool {name} already exists and update=False")

tool = ToolModel(name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema)
tool = ToolModel(
name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema, user_id=self.user_id
)
self.server.ms.add_tool(tool)
return self.server.ms.get_tool(tool_name)
return self.server.ms.get_tool(tool_name, self.user_id)

def list_tools(self):
"""List available tools.
Expand All @@ -844,7 +906,13 @@ def list_tools(self):
tools (List[ToolModel]): A list of available tools.

"""
return self.server.ms.list_tools()
return self.server.ms.list_tools(user_id=self.user_id)

def get_tool(self, name: str):
return self.server.ms.get_tool(name, user_id=self.user_id)

def delete_tool(self, name: str):
return self.server.ms.delete_tool(name, user_id=self.user_id)

# data sources

Expand Down
19 changes: 13 additions & 6 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,11 @@ def list_presets(self, user_id: uuid.UUID) -> List[Preset]:

@enforce_types
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self) -> List[ToolModel]:
def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).all()
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
return results

@enforce_types
Expand Down Expand Up @@ -677,10 +679,13 @@ def get_source(
return results[0].to_record()

@enforce_types
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]:
# TODO: add user_id when tools can eventually be added by users
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()

if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
Expand Down Expand Up @@ -752,6 +757,8 @@ def add_preset(self, preset: PresetModel):
@enforce_types
def add_tool(self, tool: ToolModel):
with self.session_maker() as session:
if self.get_tool(tool.name, tool.user_id):
raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}")
session.add(tool)
session.commit()

Expand Down Expand Up @@ -811,9 +818,9 @@ def delete_preset(self, name: str, user_id: uuid.UUID):
session.commit()

@enforce_types
def delete_tool(self, name: str):
def delete_tool(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.name == name).delete()
session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
session.commit()

# job related functions
Expand Down
7 changes: 5 additions & 2 deletions memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,18 @@ class PresetModel(BaseModel):

class ToolModel(SQLModel, table=True):
# TODO move into database
name: str = Field(..., description="The name of the function.", primary_key=True)
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.")
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")
module: Optional[str] = Field(None, description="The module of the function.")

json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")

# optional: user_id (user-specific tools)
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.")

# Needed for Column(JSON)
class Config:
arbitrary_types_allowed = True
Expand Down
11 changes: 4 additions & 7 deletions memgpt/server/rest_api/admin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def delete_tool(
# Clear the interface
interface.clear()
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name)
server.ms.delete_tool(name=tool_name, user_id=None)

@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(tool_name: str):
Expand All @@ -49,29 +49,26 @@ async def get_tool(tool_name: str):
# Clear the interface
interface.clear()
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
tool = server.ms.get_tool(tool_name=tool_name)
tool = server.ms.get_tool(tool_name=tool_name, user_id=None)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool

@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
async def list_all_tools():
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools()
tools = server.ms.list_tools(user_id=None)
return ListToolsResponse(tools=tools)

@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: CreateToolRequest = Body(...),
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
"""
Create a new tool
Expand Down
46 changes: 38 additions & 8 deletions memgpt/server/rest_api/tools/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
from typing import List, Literal, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field

from memgpt.models.pydantic_models import ToolModel
Expand All @@ -18,7 +18,7 @@ class ListToolsResponse(BaseModel):


class CreateToolRequest(BaseModel):
name: str = Field(..., description="The name of the function.")
json_schema: dict = Field(..., description="JSON schema of the tool.")
source_code: str = Field(..., description="The source code of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
Expand All @@ -31,34 +31,64 @@ class CreateToolResponse(BaseModel):
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(
tool_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name, user_id=user_id)

@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(
tool_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
tool = server.ms.get_tool(tool_name=tool_name)
tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool

@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools()
tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)

@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: CreateToolRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Create a new tool
"""
try:
return server.create_tool(
json_schema=request.json_schema,
source_code=request.source_code,
source_type=request.source_type,
tags=request.tags,
user_id=user_id,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")

return router
Loading
Loading