feat: Enable adding files #1864

Merged: 26 commits, Oct 14, 2024

Changes from all commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -2,6 +2,9 @@
# Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
# Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection

openapi_letta.json
openapi_openai.json

### Eclipse ###
.metadata
bin/
2 changes: 1 addition & 1 deletion docs/generate_docs.py
@@ -1,8 +1,8 @@
import os

from pydoc_markdown import PydocMarkdown

[Check failure on line 3 in docs/generate_docs.py (GitHub Actions / Pyright types check (3.11)): Import "pydoc_markdown" could not be resolved (reportMissingImports)]
from pydoc_markdown.contrib.loaders.python import PythonLoader

[Check failure on line 4 in docs/generate_docs.py (GitHub Actions / Pyright types check (3.11)): Import "pydoc_markdown.contrib.loaders.python" could not be resolved (reportMissingImports)]
from pydoc_markdown.contrib.processors.crossref import CrossrefProcessor

[Check failure on line 5 in docs/generate_docs.py (GitHub Actions / Pyright types check (3.11)): Import "pydoc_markdown.contrib.processors.crossref" could not be resolved (reportMissingImports)]
from pydoc_markdown.contrib.processors.filter import FilterProcessor
from pydoc_markdown.contrib.processors.smart import SmartProcessor
from pydoc_markdown.contrib.renderers.markdown import MarkdownRenderer
@@ -70,7 +70,7 @@
"Message",
"Passage",
"AgentState",
"Document",
"File",
"Source",
"LLMConfig",
"EmbeddingConfig",
2 changes: 1 addition & 1 deletion examples/notebooks/data_connector.ipynb
@@ -270,7 +270,7 @@
"outputs": [],
"source": [
"from letta.data_sources.connectors import DataConnector \n",
"from letta.schemas.document import Document\n",
"from letta.schemas.file import FileMetadata\n",
"from llama_index.core import Document as LlamaIndexDocument\n",
"from llama_index.core import SummaryIndex\n",
"from llama_index.readers.web import SimpleWebPageReader\n",
2 changes: 1 addition & 1 deletion letta/__init__.py
@@ -7,9 +7,9 @@
# imports for easier access
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.document import Document
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
25 changes: 18 additions & 7 deletions letta/agent_store/db.py
@@ -28,7 +28,7 @@
from letta.base import Base
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn

# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
@@ -141,7 +141,7 @@ class PassageModel(Base):
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
text = Column(String)
doc_id = Column(String)
file_id = Column(String)
agent_id = Column(String)
source_id = Column(String)

@@ -160,7 +160,7 @@ class PassageModel(Base):
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))

Index("passage_idx_user", user_id, agent_id, doc_id),
Index("passage_idx_user", user_id, agent_id, file_id),

def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
@@ -170,7 +170,7 @@ def to_record(self):
text=self.text,
embedding=self.embedding,
embedding_config=self.embedding_config,
doc_id=self.doc_id,
file_id=self.file_id,
user_id=self.user_id,
id=self.id,
source_id=self.source_id,
@@ -365,12 +365,17 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
self.uri = self.config.archival_storage_uri
self.db_model = PassageModel
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
self.db_model = MessageModel
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
elif table_type == TableType.FILES:
self.uri = self.config.metadata_storage_uri
self.db_model = FileMetadataModel
if self.config.metadata_storage_uri is None:
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")

@@ -487,8 +492,14 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
# TODO: eventually implement URI option
self.path = self.config.recall_storage_path
if self.path is None:
raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
raise ValueError(f"Must specify recall_storage_path in config.")
self.db_model = MessageModel
elif table_type == TableType.FILES:
self.path = self.config.metadata_storage_path
if self.path is None:
raise ValueError(f"Must specify metadata_storage_path in config.")
self.db_model = FileMetadataModel

else:
raise ValueError(f"Table type {table_type} not implemented")
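
The new TableType.FILES branches above wire file metadata into the same storage backends as messages and passages: the URI-based connector reads metadata_storage_uri, the local connector reads metadata_storage_path, and both raise a ValueError when the setting is missing. A minimal sketch of that contract, assuming the local connector class is named SQLLiteStorageConnector and that a bare LettaConfig can be constructed directly (both assumptions, not shown in this diff):

```python
# Illustrative only: TableType.FILES needs metadata_storage_path before the
# local connector can be built. Class name and config construction are assumed.
from letta.agent_store.db import SQLLiteStorageConnector  # name assumed
from letta.agent_store.storage import TableType
from letta.config import LettaConfig

config = LettaConfig()               # assumed: bare config, nothing configured
config.metadata_storage_path = None

try:
    SQLLiteStorageConnector(TableType.FILES, config, user_id="user-123")
except ValueError as err:
    print(err)  # "Must specify metadata_storage_path in config."
```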

4 changes: 2 additions & 2 deletions letta/agent_store/lancedb.py
@@ -24,7 +24,7 @@ class PassageModel(LanceModel):
id: uuid.UUID
user_id: str
text: str
doc_id: str
file_id: str
agent_id: str
data_source: str
embedding: Vector(config.default_embedding_config.embedding_dim)
@@ -37,7 +37,7 @@ def to_record(self):
return Passage(
text=self.text,
embedding=self.embedding,
doc_id=self.doc_id,
file_id=self.file_id,
user_id=self.user_id,
id=self.id,
data_source=self.data_source,
2 changes: 1 addition & 1 deletion letta/agent_store/milvus.py
@@ -26,7 +26,7 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")

# need to be converted to strings
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

def _create_collection(self):
schema = MilvusClient.create_schema(
2 changes: 1 addition & 1 deletion letta/agent_store/qdrant.py
@@ -38,7 +38,7 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
distance=models.Distance.COSINE,
),
)
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
from qdrant_client import grpc
22 changes: 12 additions & 10 deletions letta/agent_store/storage.py
@@ -10,7 +10,7 @@
from pydantic import BaseModel

from letta.config import LettaConfig
from letta.schemas.document import Document
from letta.schemas.file import FileMetadata
from letta.schemas.message import Message
from letta.schemas.passage import Passage
from letta.utils import printd
@@ -22,7 +22,7 @@ class TableType:
ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id}
RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id}
PASSAGES = "passages" # TODO
DOCUMENTS = "documents" # TODO
FILES = "files"


# table names used by Letta
@@ -33,17 +33,17 @@ class TableType:

# external data source tables
PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source)
DOCUMENT_TABLE_NAME = "letta_documents" # original documents (from source)
FILE_TABLE_NAME = "letta_files" # original files (from source)


class StorageConnector:
"""Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
"""Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory"""

type: Type[BaseModel]

def __init__(
self,
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
config: LettaConfig,
user_id,
agent_id=None,
@@ -59,9 +59,9 @@ def __init__(
elif table_type == TableType.RECALL_MEMORY:
self.type = Message
self.table_name = RECALL_TABLE_NAME
elif table_type == TableType.DOCUMENTS:
self.type = Document
self.table_name == DOCUMENT_TABLE_NAME
elif table_type == TableType.FILES:
self.type = FileMetadata
self.table_name = FILE_TABLE_NAME
elif table_type == TableType.PASSAGES:
self.type = Passage
self.table_name = PASSAGE_TABLE_NAME
@@ -74,7 +74,7 @@ def __init__(
# agent-specific table
assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS:
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
# setup base filters for user-specific tables
assert agent_id is None, "Agent ID must not be provided for user-specific tables"
self.filters = {"user_id": self.user_id}
@@ -83,7 +83,7 @@

@staticmethod
def get_storage_connector(
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
config: LettaConfig,
user_id,
agent_id=None,
@@ -92,6 +92,8 @@ def get_storage_connector(
storage_type = config.archival_storage_type
elif table_type == TableType.RECALL_MEMORY:
storage_type = config.recall_storage_type
elif table_type == TableType.FILES:
storage_type = config.metadata_storage_type
else:
raise ValueError(f"Table type {table_type} not implemented")
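
With the branch added above, file metadata flows through the same factory as archival and recall memory, keyed off config.metadata_storage_type. A usage sketch, assuming LettaConfig.load() is the standard config loader (an assumption; only the config field names appear in this diff):

```python
# Sketch: obtain a user-scoped connector for file metadata via the new branch.
from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig

config = LettaConfig.load()  # assumed loader
files_store = StorageConnector.get_storage_connector(
    TableType.FILES,  # dispatches on config.metadata_storage_type (added above)
    config,
    user_id="user-123",
    # no agent_id: FILES is a user-specific table, like PASSAGES
)
```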

2 changes: 1 addition & 1 deletion letta/cli/cli_load.py
@@ -106,7 +106,7 @@ def load_vector_database(
# document_store=None,
# passage_store=passage_storage,
# )
# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
# print(f"Loaded {num_passages} passages and {num_documents} files from {name}")
# except Exception as e:
# typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
# ms.delete_source(source_id=source.id)
51 changes: 51 additions & 0 deletions letta/client/client.py
@@ -25,6 +25,7 @@

# new schemas
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_request import LettaRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
@@ -232,6 +233,9 @@ def list_sources(self) -> List[Source]:
def list_attached_sources(self, agent_id: str) -> List[Source]:
raise NotImplementedError

def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
raise NotImplementedError

def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
raise NotImplementedError

@@ -1016,6 +1020,12 @@ def get_job(self, job_id: str) -> Job:
raise ValueError(f"Failed to get job: {response.text}")
return Job(**response.json())

def delete_job(self, job_id: str) -> Job:
response = requests.delete(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to delete job: {response.text}")
return Job(**response.json())

def list_jobs(self):
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
return [Job(**job) for job in response.json()]
@@ -1088,6 +1098,30 @@ def list_attached_sources(self, agent_id: str) -> List[Source]:
raise ValueError(f"Failed to list attached sources: {response.text}")
return [Source(**source) for source in response.json()]

def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
"""
List files from source with pagination support.

Args:
source_id (str): ID of the source
limit (int): Number of files to return
cursor (Optional[str]): Pagination cursor for fetching the next page

Returns:
List[FileMetadata]: List of files
"""
# Prepare query parameters for pagination
params = {"limit": limit, "cursor": cursor}

# Make the request to the FastAPI endpoint
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files", headers=self.headers, params=params)

if response.status_code != 200:
raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}")

# Parse the JSON response
return [FileMetadata(**metadata) for metadata in response.json()]

def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
"""
Update a source
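
The docstring above leaves the cursor semantics open; a paging loop might look like the sketch below, assuming the cursor for the next page is the id of the last FileMetadata returned and that RESTClient is importable from letta.client.client (both assumptions):

```python
# Sketch: drain all files from a source page by page.
from letta.client.client import RESTClient  # import path assumed

client = RESTClient(base_url="http://localhost:8283", token="sk-...")  # args assumed

all_files, cursor = [], None
while True:
    page = client.list_files_from_source("source-xyz", limit=100, cursor=cursor)
    all_files.extend(page)
    if len(page) < 100:      # short page: nothing left to fetch
        break
    cursor = page[-1].id     # assumed: next page starts after the last file seen
print(f"{len(all_files)} files in source")
```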
@@ -2162,6 +2196,9 @@ def load_file_into_source(self, filename: str, source_id: str, blocking=True):
def get_job(self, job_id: str):
return self.server.get_job(job_id=job_id)

def delete_job(self, job_id: str):
return self.server.delete_job(job_id)

def list_jobs(self):
return self.server.list_jobs(user_id=self.user_id)

@@ -2261,6 +2298,20 @@ def list_attached_sources(self, agent_id: str) -> List[Source]:
"""
return self.server.list_attached_sources(agent_id=agent_id)

def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
"""
List files from source.

Args:
source_id (str): ID of the source
limit (int): The # of items to return
cursor (str): The cursor for fetching the next page

Returns:
files (List[FileMetadata]): List of files
"""
return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)

def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
"""
Update a source
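
Taken together, the additions give the local client a full file lifecycle. An end-to-end sketch: create_client and create_source come from the existing client API rather than this diff, and the Job/FileMetadata fields used are assumptions.

```python
# Sketch: ingest a file into a source, inspect the job, list the files.
from letta import create_client

client = create_client()                        # assumed local-client factory
source = client.create_source(name="my-docs")   # assumed existing helper

job = client.load_file_into_source(filename="notes.txt", source_id=source.id, blocking=True)
print(client.get_job(job.id).status)            # inspect the ingestion job

for f in client.list_files_from_source(source.id, limit=10):
    print(f.id)                                 # new: enumerate ingested files

client.delete_job(job.id)                       # new: clean up the finished job
```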