From 59f1e110f85462bd041a682fe99e721d3f9037d9 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Sun, 10 Mar 2024 14:34:35 -0700 Subject: [PATCH] feat: Add data sources to REST API (#1118) --- memgpt/cli/cli_load.py | 3 +- memgpt/client/client.py | 96 +++++++++++- memgpt/data_sources/connectors.py | 11 +- memgpt/models/pydantic_models.py | 38 +++++ memgpt/server/rest_api/admin/users.py | 6 +- memgpt/server/rest_api/agents/config.py | 24 +-- memgpt/server/rest_api/server.py | 2 + memgpt/server/rest_api/sources/__init__.py | 0 memgpt/server/rest_api/sources/index.py | 165 +++++++++++++++++++++ memgpt/server/server.py | 22 ++- poetry.lock | 16 +- pyproject.toml | 1 + tests/test_client.py | 92 ++++++++---- 13 files changed, 422 insertions(+), 54 deletions(-) create mode 100644 memgpt/server/rest_api/sources/__init__.py create mode 100644 memgpt/server/rest_api/sources/index.py diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 7f7d8875f2..83c706334f 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -115,12 +115,11 @@ def load_directory( document_store=None, passage_store=passage_storage, ) + print(f"Loaded {num_passages} passages and {num_documents} documents from {name}") except Exception as e: typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) ms.delete_source(source_id=source.id) - print(f"Loaded {num_passages} passages and {num_documents} documents from {name}") - except ValueError as e: typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED) raise diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 8380a6fe94..517ee6b048 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -3,13 +3,14 @@ import uuid from typing import Dict, List, Union, Optional, Tuple -from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig +from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, 
Source from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer from memgpt.metadata import MetadataStore +from memgpt.data_sources.connectors import DataConnector def create_client(base_url: Optional[str] = None, token: Optional[str] = None): @@ -69,6 +70,30 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]: def save(self): raise NotImplementedError + def list_sources(self): + """List loaded sources""" + raise NotImplementedError + + def delete_source(self): + """Delete a source and associated data (including attached to agents)""" + raise NotImplementedError + + def load_file_into_source(self, filename: str, source_id: uuid.UUID): + """Load {filename} and insert into source""" + raise NotImplementedError + + def create_source(self, name: str): + """Create a new source""" + raise NotImplementedError + + def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): + """Attach a source to an agent""" + raise NotImplementedError + + def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID): + """Detach a source from an agent""" + raise NotImplementedError + class RESTClient(AbstractClient): def __init__( @@ -127,17 +152,17 @@ def create_agent( ) return agent_state - def delete_agent(self, agent_id: str): - response = requests.delete(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers) - return agent_id + def delete_agent(self, agent_id: uuid.UUID): + response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers) + assert response.status_code == 200, f"Failed to delete agent: {response.text}" def create_preset(self, preset: Preset): raise NotImplementedError - def get_agent_config(self, agent_id: str) -> AgentState: + def get_agent_config(self, 
agent_id: uuid.UUID) -> AgentState: raise NotImplementedError - def get_agent_memory(self, agent_id: str) -> Dict: + def get_agent_memory(self, agent_id: uuid.UUID) -> Dict: raise NotImplementedError def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> Dict: @@ -157,6 +182,53 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]: def save(self): raise NotImplementedError + def list_sources(self): + """List loaded sources""" + response = requests.get(f"{self.base_url}/api/sources", headers=self.headers) + response_json = response.json() + return response_json + + def delete_source(self, source_id: uuid.UUID): + """Delete a source and associated data (including attached to agents)""" + response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers) + assert response.status_code == 200, f"Failed to delete source: {response.text}" + + def load_file_into_source(self, filename: str, source_id: uuid.UUID): + """Load {filename} and insert into source""" + params = {"source_id": str(source_id)} + files = {"file": open(filename, "rb")} + response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers) + return response.json() + + def create_source(self, name: str) -> Source: + """Create a new source""" + payload = {"name": name} + response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers) + response_json = response.json() + print("CREATE SOURCE", response_json, response.text) + return Source( + id=uuid.UUID(response_json["id"]), + name=response_json["name"], + user_id=uuid.UUID(response_json["user_id"]), + created_at=datetime.datetime.fromtimestamp(response_json["created_at"]), + embedding_dim=response_json["embedding_config"]["embedding_dim"], + embedding_model=response_json["embedding_config"]["embedding_model"], + ) + + def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID): + """Attach a 
source to an agent""" + params = {"source_name": source_name, "agent_id": agent_id} + response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers) + assert response.status_code == 200, f"Failed to attach source to agent: {response.text}" + return response.json() + + def detach_source(self, source_name: str, agent_id: uuid.UUID): + """Detach a source from an agent""" + params = {"source_name": source_name, "agent_id": str(agent_id)} + response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers) + assert response.status_code == 200, f"Failed to detach source from agent: {response.text}" + return response.json() + class LocalClient(AbstractClient): def __init__( @@ -267,3 +339,15 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]: def save(self): self.server.save_agents() + + def load_data(self, connector: DataConnector, source_name: str): + self.server.load_data(user_id=self.user_id, connector=connector, source_name=source_name) + + def create_source(self, name: str): + self.server.create_source(user_id=self.user_id, name=name) + + def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID): + self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id) + + def delete_agent(self, agent_id: uuid.UUID): + self.server.delete_agent(user_id=self.user_id, agent_id=agent_id) diff --git a/memgpt/data_sources/connectors.py b/memgpt/data_sources/connectors.py index 7d2b544679..ca5f46f16c 100644 --- a/memgpt/data_sources/connectors.py +++ b/memgpt/data_sources/connectors.py @@ -5,6 +5,7 @@ from memgpt.data_types import Document, Passage from typing import List, Iterator, Dict, Tuple, Optional +import typer from llama_index.core import Document as LlamaIndexDocument @@ -53,7 +54,15 @@ def load_data( # generate passages for passage_text, passage_metadata in connector.generate_passages([document], 
chunk_size=embedding_config.embedding_chunk_size): - embedding = embed_model.get_text_embedding(passage_text) + try: + embedding = embed_model.get_text_embedding(passage_text) + except Exception as e: + typer.secho( + f"Warning: Failed to get embedding for {passage_text} (error: {str(e)}), skipping insert into VectorDB.", + fg=typer.colors.YELLOW, + ) + continue + passage = Passage( id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"), text=passage_text, diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index c5f06987fa..11cf8fb9a4 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -1,8 +1,11 @@ from typing import List, Optional, Dict, Literal from pydantic import BaseModel, Field, Json, ConfigDict import uuid +import base64 +import numpy as np from datetime import datetime from sqlmodel import Field, SQLModel +from sqlalchemy import JSON, Column, BINARY, TypeDecorator from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM from memgpt.utils import get_human_text, get_persona_text, printd @@ -83,3 +86,38 @@ class PersonaModel(SQLModel, table=True): name: str = Field(..., description="The name of the persona.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True) user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.") + + +class SourceModel(SQLModel, table=True): + name: str = Field(..., description="The name of the source.") + description: str = Field(None, description="The description of the source.") + user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.") + created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the source was created.") + id: uuid.UUID = 
Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True) + # embedding info + # embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.") + embedding_config: Optional[EmbeddingConfigModel] = Field( + None, sa_column=Column(JSON), description="The embedding configuration used by the passage." + ) + + +class PassageModel(BaseModel): + user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.") + agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.") + text: str = Field(..., description="The text of the passage.") + embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.") + embedding_config: Optional[EmbeddingConfigModel] = Field( + None, sa_column=Column(JSON), description="The embedding configuration used by the passage." + ) + data_source: Optional[str] = Field(None, description="The data source of the passage.") + doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.") + id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True) + metadata: Optional[Dict] = Field({}, description="The metadata of the passage.") + + +class DocumentModel(BaseModel): + user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.") + text: str = Field(..., description="The text of the document.") + data_source: str = Field(..., description="The data source of the document.") + id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True) + metadata: Optional[Dict] = Field({}, description="The metadata of the document.") diff --git a/memgpt/server/rest_api/admin/users.py 
b/memgpt/server/rest_api/admin/users.py index 1d84dc1783..0b7468efc4 100644 --- a/memgpt/server/rest_api/admin/users.py +++ b/memgpt/server/rest_api/admin/users.py @@ -100,10 +100,8 @@ def create_user(request: Optional[CreateUserRequest] = Body(None)): raise HTTPException(status_code=500, detail=f"{e}") return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token) - @router.delete("/users", tags=["admin"], response_model=DeleteUserResponse) - def delete_user( - user_id: str = Query(..., description="The ID of the user to be deleted."), - ): + @router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse) + def delete_user(user_id): # TODO make a soft deletion, instead of a hard deletion try: user_id_uuid = uuid.UUID(user_id) diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index 05c5546646..069c2e97f6 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -1,11 +1,11 @@ import re import uuid from functools import partial -from typing import List, Optional -from fastapi import APIRouter, Body, Depends, HTTPException, status +from fastapi import APIRouter, Body, Depends, Query, HTTPException, status from fastapi.responses import JSONResponse from pydantic import BaseModel, Field +from typing import List, Optional from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel from memgpt.server.rest_api.auth_token import get_current_user @@ -20,6 +20,7 @@ class GetAgentRequest(BaseModel): class AgentRenameRequest(BaseModel): + agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.") agent_name: str = Field(..., description="New name for the agent.") @@ -50,9 +51,9 @@ def validate_agent_name(name: str) -> str: def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = 
partial(partial(get_current_user, server), password) - @router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse) + @router.get("/agents", tags=["agents"], response_model=GetAgentResponse) def get_agent_config( - agent_id: uuid.UUID, + agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."), user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ @@ -68,8 +69,8 @@ def get_agent_config( interface.clear() agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id) # return GetAgentResponse(agent_state=agent_state) - LLMConfigModel(**vars(agent_state.llm_config)) - EmbeddingConfigModel(**vars(agent_state.embedding_config)) + llm_config = LLMConfigModel(**vars(agent_state.llm_config)) + embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) return GetAgentResponse( agent_state=AgentStateModel( @@ -89,9 +90,8 @@ def get_agent_config( sources=attached_sources, ) - @router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse) + @router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse) def update_agent_name( - agent_id: uuid.UUID, request: AgentRenameRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): @@ -100,6 +100,8 @@ def update_agent_name( This changes the name of the agent in the database but does NOT edit the agent's persona. """ + agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + valid_name = validate_agent_name(request.agent_name) interface.clear() @@ -113,15 +115,13 @@ def update_agent_name( @router.delete("/agents/{agent_id}", tags=["agents"]) def delete_agent( - agent_id: uuid.UUID, + agent_id, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ Delete an agent. 
""" - request = GetAgentRequest(agent_id=agent_id) - - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + agent_id = uuid.UUID(agent_id) interface.clear() try: diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index 6afb738f22..6a52ea11c3 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -21,6 +21,7 @@ from memgpt.server.rest_api.personas.index import setup_personas_index_router from memgpt.server.rest_api.static_files import mount_static_files from memgpt.server.rest_api.tools.index import setup_tools_index_router +from memgpt.server.rest_api.sources.index import setup_sources_index_router from memgpt.server.server import SyncServer """ @@ -94,6 +95,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX) +app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX) # /api/config endpoints app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX) diff --git a/memgpt/server/rest_api/sources/__init__.py b/memgpt/server/rest_api/sources/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py new file mode 100644 index 0000000000..f49d7cafeb --- /dev/null +++ b/memgpt/server/rest_api/sources/index.py @@ -0,0 +1,165 @@ +import uuid +from functools import partial +from typing import List, Optional + +from fastapi import APIRouter, Body, Depends, Query, HTTPException, status, UploadFile +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +from memgpt.models.pydantic_models 
import SourceModel, PassageModel, DocumentModel +from memgpt.server.rest_api.auth_token import get_current_user +from memgpt.server.rest_api.interface import QueuingInterface +from memgpt.server.server import SyncServer +from memgpt.data_types import Source +from memgpt.data_sources.connectors import DirectoryConnector + +router = APIRouter() + +""" +Implement the following functions: +* List all available sources +* Create a new source +* Delete a source +* Upload a file to a server that is loaded into a specific source +* Paginated get all passages from a source +* Paginated get all documents from a source +* Attach a source to an agent +""" + + +class ListSourcesResponse(BaseModel): + sources: List[SourceModel] = Field(..., description="List of available sources") + + +class CreateSourceRequest(BaseModel): + name: str = Field(..., description="The name of the source.") + description: Optional[str] = Field(None, description="The description of the source.") + + +class CreateSourceResponse(BaseModel): + source: SourceModel = Field(..., description="The newly created source.") + + +class UploadFileToSourceRequest(BaseModel): + file: UploadFile = Field(..., description="The file to upload.") + + +class UploadFileToSourceResponse(BaseModel): + source: SourceModel = Field(..., description="The source the file was uploaded to.") + added_passages: int = Field(..., description="The number of passages added to the source.") + added_documents: int = Field(..., description="The number of documents added to the source.") + + +class GetSourcePassagesResponse(BaseModel): + passages: List[PassageModel] = Field(..., description="List of passages from the source.") + + +class GetSourceDocumentsResponse(BaseModel): + documents: List[DocumentModel] = Field(..., description="List of documents from the source.") + + +def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str): + get_current_user_with_server = partial(partial(get_current_user, 
server), password) + + @router.get("/sources", tags=["sources"], response_model=ListSourcesResponse) + async def list_source( + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + # Clear the interface + interface.clear() + + sources = server.ms.list_sources(user_id=user_id) + return ListSourcesResponse(sources=sources) + + @router.post("/sources", tags=["sources"], response_model=SourceModel) + async def create_source( + request: CreateSourceRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + # TODO: don't use Source and just use SourceModel once pydantic migration is complete + source = server.create_source(name=request.name, user_id=user_id) + return SourceModel( + name=source.name, + description=None, # TODO: actually store descriptions + user_id=source.user_id, + id=source.id, + embedding_config=server.server_embedding_config, + created_at=source.created_at.timestamp(), + ) + + @router.delete("/sources/{source_id}", tags=["sources"]) + async def delete_source( + source_id, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + try: + server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id) + return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"}) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + @router.post("/sources/attach", tags=["sources"], response_model=SourceModel) + async def attach_source_to_agent( + agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."), + source_name: str = Query(..., description="The name of the source to attach."), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}" + assert isinstance(user_id, 
uuid.UUID), f"Expected user_id to be a UUID, got {user_id}" + source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id) + return SourceModel( + name=source.name, + description=None, # TODO: actually store descriptions + user_id=source.user_id, + id=source.id, + embedding_config=server.server_embedding_config, + created_at=source.created_at, + ) + + @router.post("/sources/detach", tags=["sources"], response_model=SourceModel) + async def detach_source_from_agent( + agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."), + source_name: str = Query(..., description="The name of the source to detach."), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id) + + @router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse) + async def upload_file_to_source( + # file: UploadFile = UploadFile(..., description="The file to upload."), + file: UploadFile, + source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + source = server.ms.get_source(source_id=source_id, user_id=user_id) + + # create a directory connector that reads the in-memory file + connector = DirectoryConnector(input_files=[file.filename]) + + # load the data into the source via the connector + server.load_data(user_id=user_id, source_name=source.name, connector=connector) + + # TODO: actually return added passages/documents + return UploadFileToSourceResponse(source=source, added_passages=0, added_documents=0) + + @router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse) + async def list_passages( + source_id: uuid.UUID = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + raise 
NotImplementedError + + @router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse) + async def list_documents( + source_id: uuid.UUID = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + raise NotImplementedError + + return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index e69c5feaff..bf739208ae 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1075,13 +1075,25 @@ def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add o embedding_dim=self.config.default_embedding_config.embedding_dim, ) self.ms.create_source(source) + assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}" return source + def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID): + """Delete a data source""" + source = self.ms.get_source(source_id=source_id, user_id=user_id) + self.ms.delete_source(source_id) + + # delete data from passage store + passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + passage_store.delete({"data_source": source.name}) + + # TODO: delete data from agent passage stores (?) 
+ def load_data( self, user_id: uuid.UUID, connector: DataConnector, - source_name: Source, + source_name: str, ): """Load data from a DataConnector into a source for a specified user_id""" # TODO: this should be implemented as a batch job or at least async, since it may take a long time @@ -1103,7 +1115,7 @@ def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source # attach a data source to an agent data_source = self.ms.get_source(source_name=source_name, user_id=user_id) if data_source is None: - raise ValueError(f"Data source {source_name} does not exist") + raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}") # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) @@ -1114,6 +1126,12 @@ def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source # attach source to agent agent.attach_source(data_source.name, source_connector, self.ms) + return data_source + + def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str): + # TODO: remove all passages corresponding to source from agent's archival memory + raise NotImplementedError + def list_attached_sources(self, agent_id: uuid.UUID): # list all attached sources to an agent return self.ms.list_attached_sources(agent_id) diff --git a/poetry.lock b/poetry.lock index 710299599d..f3e85228ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4103,6 +4103,20 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.9" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, + {file = "python_multipart-0.0.9.tar.gz", hash = 
"sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, +] + +[package.extras] +dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"] + [[package]] name = "pytz" version = "2023.4" @@ -5898,4 +5912,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.10" -content-hash = "1b35809af89064c19823842ed40f5c8b9ceb8e68315fb482e2ae3b9f8cac0fad" +content-hash = "509bcb6fde67eb0c2d0a3997d6401e752d4c0cf7397a363f35167e6dc714ab0b" diff --git a/pyproject.toml b/pyproject.toml index b79a6996a4..1355eb4acf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ llama-index = "^0.10.6" llama-index-embeddings-openai = "^0.1.1" llama-index-embeddings-huggingface = {version = "^0.1.4", optional = true} llama-index-embeddings-azure-openai = "^0.1.6" +python-multipart = "^0.0.9" [tool.poetry.extras] local = ["llama-index-embeddings-huggingface"] diff --git a/tests/test_client.py b/tests/test_client.py index 6d551f3f91..f5397c7ec9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,16 +1,15 @@ import uuid +import os import time import threading +from dotenv import load_dotenv from memgpt import Admin, create_client from memgpt.constants import DEFAULT_PRESET import pytest - - import uuid - test_agent_name = f"test_client_{str(uuid.uuid4())}" # test_preset_name = "test_preset" test_preset_name = DEFAULT_PRESET @@ -60,6 +59,7 @@ def user_token(): # Fixture to create clients with different configurations @pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module") +# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module") def client(request, user_token): # use token or not if request.param["base_url"]: @@ 
-71,6 +71,17 @@ def client(request, user_token): yield client +# Fixture for test agent +@pytest.fixture(scope="module") +def agent(client): + agent_state = client.create_agent(name=test_agent_name, preset=test_preset_name) + print("AGENT ID", agent_state.id) + yield agent_state + + # delete agent + client.delete_agent(agent_state.id) + + # TODO: add back once REST API supports # def test_create_preset(client): # @@ -86,26 +97,55 @@ def client(request, user_token): # client.create_preset(preset) -def test_create_agent(client): - global test_agent_state - test_agent_state = client.create_agent( - name=test_agent_name, - preset=test_preset_name, - ) - print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") - assert test_agent_state is not None - - -def test_user_message(client): - """Test that we can send a message through the client""" - assert client is not None, "Run create_agent test first" - print(f"\n\n[2] SENDING MESSAGE TO AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") - response = client.user_message(agent_id=test_agent_state.id, message="Hello my name is Test, Client Test") - assert response is not None and len(response) > 0 - - # global test_agent_state_post_message - # client.server.active_agents[0]["agent"].update_state() - # test_agent_state_post_message = client.server.active_agents[0]["agent"].agent_state - # print( - # f"[2] MESSAGE SEND SUCCESS!!! 
AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}" - # ) +# def test_create_agent(client): +# global test_agent_state +# test_agent_state = client.create_agent( +# name=test_agent_name, +# preset=test_preset_name, +# ) +# print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") +# assert test_agent_state is not None + + +def test_sources(client, agent): + + if not hasattr(client, "base_url"): + pytest.skip("Skipping test_sources because base_url is None") + + # list sources + sources = client.list_sources() + print("listed sources", sources) + + # create a source + source = client.create_source(name="test_source") + + # list sources + sources = client.list_sources() + print("listed sources", sources) + assert len(sources) == 1 + + # load a file into a source + filename = "CONTRIBUTING.md" + response = client.load_file_into_source(filename, source.id) + print(response) + + # attach a source + # TODO: make sure things run in the right order + client.attach_source_to_agent(source_name="test_source", agent_id=agent.id) + + # TODO: list archival memory + + # detach the source + # TODO: add when implemented + # client.detach_source(source.name, agent.id) + + # delete the source + client.delete_source(source.id) + + +# def test_user_message(client, agent): +# """Test that we can send a message through the client""" +# assert client is not None, "Run create_agent test first" +# print(f"\n\n[2] SENDING MESSAGE TO AGENT {agent.id}!!!\n\tmessages={agent.state['messages']}") +# response = client.user_message(agent_id=agent.id, message="Hello my name is Test, Client Test") +# assert response is not None and len(response) > 0