Skip to content

Commit

Permalink
feat: Add data sources to REST API (#1118)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Mar 10, 2024
1 parent 2b7d3e8 commit 59f1e11
Show file tree
Hide file tree
Showing 13 changed files with 422 additions and 54 deletions.
3 changes: 1 addition & 2 deletions memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,11 @@ def load_directory(
document_store=None,
passage_store=passage_storage,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
ms.delete_source(source_id=source.id)

print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")

except ValueError as e:
typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
raise
Expand Down
96 changes: 90 additions & 6 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import uuid
from typing import Dict, List, Union, Optional, Tuple

from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.metadata import MetadataStore
from memgpt.data_sources.connectors import DataConnector


def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
Expand Down Expand Up @@ -69,6 +70,30 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]:
def save(self):
raise NotImplementedError

def list_sources(self):
"""List loaded sources"""
raise NotImplementedError

def delete_source(self):
"""Delete a source and associated data (including attached to agents)"""
raise NotImplementedError

def load_file_into_source(self, filename: str, source_id: uuid.UUID):
"""Load {filename} and insert into source"""
raise NotImplementedError

def create_source(self, name: str):
"""Create a new source"""
raise NotImplementedError

def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Attach a source to an agent"""
raise NotImplementedError

def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Detach a source from an agent"""
raise NotImplementedError


class RESTClient(AbstractClient):
def __init__(
Expand Down Expand Up @@ -127,17 +152,17 @@ def create_agent(
)
return agent_state

def delete_agent(self, agent_id: str):
response = requests.delete(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
return agent_id
def delete_agent(self, agent_id: uuid.UUID):
response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete agent: {response.text}"

def create_preset(self, preset: Preset):
raise NotImplementedError

def get_agent_config(self, agent_id: str) -> AgentState:
def get_agent_config(self, agent_id: uuid.UUID) -> AgentState:
raise NotImplementedError

def get_agent_memory(self, agent_id: str) -> Dict:
def get_agent_memory(self, agent_id: uuid.UUID) -> Dict:
raise NotImplementedError

def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> Dict:
Expand All @@ -157,6 +182,53 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]:
def save(self):
raise NotImplementedError

def list_sources(self):
"""List loaded sources"""
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
response_json = response.json()
return response_json

def delete_source(self, source_id: uuid.UUID):
"""Delete a source and associated data (including attached to agents)"""
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete source: {response.text}"

def load_file_into_source(self, filename: str, source_id: uuid.UUID):
"""Load {filename} and insert into source"""
params = {"source_id": str(source_id)}
files = {"file": open(filename, "rb")}
response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers)
return response.json()

def create_source(self, name: str) -> Source:
"""Create a new source"""
payload = {"name": name}
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
response_json = response.json()
print("CREATE SOURCE", response_json, response.text)
return Source(
id=uuid.UUID(response_json["id"]),
name=response_json["name"],
user_id=uuid.UUID(response_json["user_id"]),
created_at=datetime.datetime.fromtimestamp(response_json["created_at"]),
embedding_dim=response_json["embedding_config"]["embedding_dim"],
embedding_model=response_json["embedding_config"]["embedding_model"],
)

def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
"""Attach a source to an agent"""
params = {"source_name": source_name, "agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
return response.json()

def detach_source(self, source_name: str, agent_id: uuid.UUID):
"""Detach a source from an agent"""
params = {"source_name": source_name, "agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return response.json()


class LocalClient(AbstractClient):
def __init__(
Expand Down Expand Up @@ -267,3 +339,15 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]:

def save(self):
self.server.save_agents()

def load_data(self, connector: DataConnector, source_name: str):
self.server.load_data(user_id=self.user_id, connector=connector, source_name=source_name)

def create_source(self, name: str):
self.server.create_source(user_id=self.user_id, name=name)

def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id)

def delete_agent(self, agent_id: uuid.UUID):
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
11 changes: 10 additions & 1 deletion memgpt/data_sources/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from memgpt.data_types import Document, Passage

from typing import List, Iterator, Dict, Tuple, Optional
import typer
from llama_index.core import Document as LlamaIndexDocument


Expand Down Expand Up @@ -53,7 +54,15 @@ def load_data(

# generate passages
for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
embedding = embed_model.get_text_embedding(passage_text)
try:
embedding = embed_model.get_text_embedding(passage_text)
except Exception as e:
typer.secho(
f"Warning: Failed to get embedding for {passage_text} (error: {str(e)}), skipping insert into VectorDB.",
fg=typer.colors.YELLOW,
)
continue

passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
text=passage_text,
Expand Down
38 changes: 38 additions & 0 deletions memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field, Json, ConfigDict
import uuid
import base64
import numpy as np
from datetime import datetime
from sqlmodel import Field, SQLModel
from sqlalchemy import JSON, Column, BINARY, TypeDecorator

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_human_text, get_persona_text, printd
Expand Down Expand Up @@ -83,3 +86,38 @@ class PersonaModel(SQLModel, table=True):
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")


class SourceModel(SQLModel, table=True):
name: str = Field(..., description="The name of the source.")
description: str = Field(None, description="The description of the source.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the source was created.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
# embedding info
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
)


class PassageModel(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
text: str = Field(..., description="The text of the passage.")
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
)
data_source: Optional[str] = Field(None, description="The data source of the passage.")
doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")


class DocumentModel(BaseModel):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
text: str = Field(..., description="The text of the document.")
data_source: str = Field(..., description="The data source of the document.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
6 changes: 2 additions & 4 deletions memgpt/server/rest_api/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,8 @@ def create_user(request: Optional[CreateUserRequest] = Body(None)):
raise HTTPException(status_code=500, detail=f"{e}")
return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token)

@router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(
user_id: str = Query(..., description="The ID of the user to be deleted."),
):
@router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(user_id):
# TODO make a soft deletion, instead of a hard deletion
try:
user_id_uuid = uuid.UUID(user_id)
Expand Down
24 changes: 12 additions & 12 deletions memgpt/server/rest_api/agents/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import re
import uuid
from functools import partial
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional

from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
Expand All @@ -20,6 +20,7 @@ class GetAgentRequest(BaseModel):


class AgentRenameRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
agent_name: str = Field(..., description="New name for the agent.")


Expand Down Expand Up @@ -50,9 +51,9 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse)
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: uuid.UUID,
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Expand All @@ -68,8 +69,8 @@ def get_agent_config(
interface.clear()
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
# return GetAgentResponse(agent_state=agent_state)
LLMConfigModel(**vars(agent_state.llm_config))
EmbeddingConfigModel(**vars(agent_state.embedding_config))
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

return GetAgentResponse(
agent_state=AgentStateModel(
Expand All @@ -89,9 +90,8 @@ def get_agent_config(
sources=attached_sources,
)

@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
Expand All @@ -100,6 +100,8 @@ def update_agent_name(
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

valid_name = validate_agent_name(request.agent_name)

interface.clear()
Expand All @@ -113,15 +115,13 @@ def update_agent_name(

@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: uuid.UUID,
agent_id,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
request = GetAgentRequest(agent_id=agent_id)

agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
agent_id = uuid.UUID(agent_id)

interface.clear()
try:
Expand Down
2 changes: 2 additions & 0 deletions memgpt/server/rest_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.server import SyncServer

"""
Expand Down Expand Up @@ -94,6 +95,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)

# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
Expand Down
Empty file.
Loading

0 comments on commit 59f1e11

Please sign in to comment.