From ade6fc0fe235beb4e6098a1e2d814539bfe4a74c Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Wed, 9 Oct 2024 17:45:43 -0700 Subject: [PATCH 01/23] Modify gitignore to ignore auto generated openapi files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ddcdb97a72..1fcffd8a4c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ # Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection # Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection +openapi_letta.json +openapi_openai.json + ### Eclipse ### .metadata bin/ From ab56a087a05866b957026f7a11969a9e8cec5bab Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Wed, 9 Oct 2024 17:47:22 -0700 Subject: [PATCH 02/23] Add delete job to client --- letta/client/client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/letta/client/client.py b/letta/client/client.py index 6e43601aef..88e89521fd 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1016,6 +1016,12 @@ def get_job(self, job_id: str) -> Job: raise ValueError(f"Failed to get job: {response.text}") return Job(**response.json()) + def delete_job(self, job_id: str) -> Job: + response = requests.delete(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to delete job: {response.text}") + return Job(**response.json()) + def list_jobs(self): response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers) return [Job(**job) for job in response.json()] @@ -2162,6 +2168,9 @@ def load_file_into_source(self, filename: str, source_id: str, blocking=True): def get_job(self, job_id: str): return self.server.get_job(job_id=job_id) + def delete_job(self, job_id: str): + return self.server.delete_job(job_id) + def list_jobs(self): return self.server.list_jobs(user_id=self.user_id) From ea254ca4257940f27ea48ad0f320722e00d686d0 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Wed, 9 Oct 2024 17:48:22 -0700 Subject: [PATCH 03/23] Add delete jobs route --- letta/server/rest_api/routers/v1/jobs.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index bd581a98ef..3f3fef17cb 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -1,6 +1,6 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, Header, Query +from fastapi import APIRouter, Depends, Header, HTTPException, Query from letta.schemas.job import Job from letta.server.rest_api.utils import get_letta_server @@ -54,3 +54,19 @@ def get_job( """ return server.get_job(job_id=job_id) + + +@router.delete("/{job_id}", response_model=Job, operation_id="delete_job") +def delete_job( + job_id: str, + server: "SyncServer" = Depends(get_letta_server), +): + """ + Delete a job by its job_id. 
+ """ + job = server.get_job(job_id=job_id) + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + server.delete_job(job_id=job_id) + return job From 89d8adc5e100ba024b2c624077bc484f3756553f Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Wed, 9 Oct 2024 17:52:16 -0700 Subject: [PATCH 04/23] Modify test_client --- tests/test_client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index b81a0acd51..aac72e94db 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -305,6 +305,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): for source in client.list_sources(): client.delete_source(source.id) + # clear jobs + for job in client.list_jobs(): + client.delete_job(job.id) + # list sources sources = client.list_sources() print("listed sources", sources) From 761557dd4fca812779747cdc4017b783989631b9 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 10 Oct 2024 13:38:48 -0700 Subject: [PATCH 05/23] Finish adding documents table --- letta/agent_store/db.py | 49 +++++++++------------ letta/agent_store/storage.py | 6 ++- letta/client/client.py | 31 +++++++++++++ letta/config.py | 5 +++ letta/data_sources/connectors.py | 15 +++---- letta/metadata.py | 35 +++++++++++++++ letta/schemas/document.py | 10 ++++- letta/server/rest_api/routers/v1/sources.py | 10 ++--- letta/server/server.py | 14 +++--- tests/test_client.py | 48 ++++++++++++++++++++ 10 files changed, 170 insertions(+), 53 deletions(-) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 585de6edee..dbaaaa3a5f 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -27,7 +27,7 @@ from letta.agent_store.storage import StorageConnector, TableType from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM -from letta.metadata import EmbeddingConfigColumn, ToolCallColumn +from letta.metadata import DocumentModel, EmbeddingConfigColumn, ToolCallColumn # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.message import Message @@ -365,12 +365,17 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) self.uri = self.config.archival_storage_uri self.db_model = PassageModel if self.config.archival_storage_uri is None: - raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}") + raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}") elif table_type == TableType.RECALL_MEMORY: self.uri = self.config.recall_storage_uri self.db_model = MessageModel if self.config.recall_storage_uri is None: - raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}") + raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") + elif table_type == TableType.DOCUMENTS: + self.uri = self.config.documents_storage_uri + self.db_model = DocumentModel + if self.config.documents_storage_uri is None: + raise ValueError(f"Must specify documents_storage_uri in config {self.config.config_path}") else: raise ValueError(f"Table type {table_type} not implemented") @@ -398,6 +403,8 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Op return records def insert_many(self, records, exists_ok=True, show_progress=False): + pass + # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel) if len(records) == 0: return @@ -487,8 +494,14 @@ def 
__init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) # TODO: eventually implement URI option self.path = self.config.recall_storage_path if self.path is None: - raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}") + raise ValueError(f"Must specify recall_storage_path in config.") self.db_model = MessageModel + elif table_type == TableType.DOCUMENTS: + self.path = self.config.documents_storage_path + if self.path is None: + raise ValueError(f"Must specify documents_storage_path in config.") + self.db_model = DocumentModel + else: raise ValueError(f"Table type {table_type} not implemented") @@ -504,36 +517,18 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b)) def insert_many(self, records, exists_ok=True, show_progress=False): + pass + # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel) if len(records) == 0: return - - added_ids = [] # avoid adding duplicates - # NOTE: this has not great performance due to the excessive commits with self.session_maker() as session: iterable = tqdm(records) if show_progress else records for record in iterable: # db_record = self.db_model(**vars(record)) - - if record.id in added_ids: - continue - - existing_record = session.query(self.db_model).filter_by(id=record.id).first() - if existing_record: - if exists_ok: - fields = record.model_dump() - fields.pop("id") - session.query(self.db_model).filter(self.db_model.id == record.id).update(fields) - session.commit() - else: - raise ValueError(f"Record with id {record.id} already exists.") - - else: - db_record = self.db_model(**record.dict()) - session.add(db_record) - session.commit() - - added_ids.append(record.id) + db_record = self.db_model(**record.dict()) + session.add(db_record) + session.commit() def insert(self, record, exists_ok=True): self.insert_many([record], exists_ok=exists_ok) diff --git a/letta/agent_store/storage.py b/letta/agent_store/storage.py index 7412010de8..785db6b613 100644 --- a/letta/agent_store/storage.py +++ b/letta/agent_store/storage.py @@ -22,7 +22,7 @@ class TableType: ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id} RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id} PASSAGES = "passages" # TODO - DOCUMENTS = "documents" # TODO + DOCUMENTS = "documents" # table names used by Letta @@ -61,7 +61,7 @@ def __init__( self.table_name = RECALL_TABLE_NAME elif table_type == TableType.DOCUMENTS: self.type = Document - self.table_name == DOCUMENT_TABLE_NAME + self.table_name = DOCUMENT_TABLE_NAME elif table_type == TableType.PASSAGES: self.type = Passage self.table_name = PASSAGE_TABLE_NAME @@ -92,6 +92,8 @@ def get_storage_connector( storage_type = config.archival_storage_type elif table_type == TableType.RECALL_MEMORY: storage_type = config.recall_storage_type + elif table_type == TableType.DOCUMENTS: + storage_type = config.documents_storage_type else: raise ValueError(f"Table type {table_type} not implemented") diff --git a/letta/client/client.py b/letta/client/client.py index 88e89521fd..f9cbaccce9 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -21,6 +21,7 @@ UpdateHuman, UpdatePersona, ) +from letta.schemas.document import Document from letta.schemas.embedding_config import EmbeddingConfig # new schemas @@ -232,6 +233,9 @@ def list_sources(self) -> List[Source]: def 
list_attached_sources(self, agent_id: str) -> List[Source]: raise NotImplementedError + def list_documents_from_source(self, source_id: str) -> List[Document]: + raise NotImplementedError + def update_source(self, source_id: str, name: Optional[str] = None) -> Source: raise NotImplementedError @@ -1094,6 +1098,21 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: raise ValueError(f"Failed to list attached sources: {response.text}") return [Source(**source) for source in response.json()] + def list_documents_from_source(self, source_id: str) -> List[Document]: + """ + List documents from source. + + Args: + source_id (str): ID of the source + + Returns: + documents (List[Document]): List of documents + """ + response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/documents", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to list documents with source id {source_id}: [{response.status_code}] {response.text}") + return [Document(**document) for document in response.json()] + def update_source(self, source_id: str, name: Optional[str] = None) -> Source: """ Update a source @@ -2270,6 +2289,18 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: """ return self.server.list_attached_sources(agent_id=agent_id) + def list_documents_from_source(self, source_id: str) -> List[Document]: + """ + List documents from source. + + Args: + source_id (str): ID of the source + + Returns: + documents (List[Document]): List of documents + """ + return self.server.list_documents_from_source(source_id=source_id) + def update_source(self, source_id: str, name: Optional[str] = None) -> Source: """ Update a source diff --git a/letta/config.py b/letta/config.py index b07c7f4cbc..e03bac9bfd 100644 --- a/letta/config.py +++ b/letta/config.py @@ -78,6 +78,11 @@ class LettaConfig: metadata_storage_path: str = LETTA_DIR metadata_storage_uri: str = None + # database configs: document storage + documents_storage_type: str = "sqlite" + documents_storage_path: str = LETTA_DIR + documents_storage_uri: str = None + # database configs: agent state persistence_manager_type: str = None # in-memory, db persistence_manager_save_file: str = None # local file diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 2ff66519af..731850528b 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Tuple import typer from llama_index.core import Document as LlamaIndexDocument @@ -41,7 +41,7 @@ def load_data( connector: DataConnector, source: Source, passage_store: StorageConnector, - document_store: Optional[StorageConnector] = None, + document_store: StorageConnector, ): """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" embedding_config = source.embedding_config @@ -57,14 +57,14 @@ def load_data( for document_text, document_metadata in connector.generate_documents(): # insert document into storage document = Document( + id=create_uuid_from_string(f"{str(source.id)}_{document_text}"), + user_id=source.user_id, + source_id=source.id, text=document_text, metadata_=document_metadata, - source_id=source.id, - user_id=source.user_id, ) document_count += 1 - if document_store: - document_store.insert(document) + document_store.insert(document) # generate passages for passage_text, passage_metadata in 
connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size): @@ -158,9 +158,6 @@ def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Docum llama_index_docs = reader.load_data(show_progress=True) for llama_index_doc in llama_index_docs: - # TODO: add additional metadata? - # doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata) - # docs.append(doc) yield llama_index_doc.text, llama_index_doc.metadata def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: diff --git a/letta/metadata.py b/letta/metadata.py index 3e56fddbe3..301fe2c51f 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -12,6 +12,7 @@ DateTime, Index, String, + Text, TypeDecorator, desc, func, @@ -23,6 +24,7 @@ from letta.schemas.agent import AgentState from letta.schemas.api_key import APIKey from letta.schemas.block import Block, Human, Persona +from letta.schemas.document import Document from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus from letta.schemas.job import Job @@ -39,6 +41,33 @@ Base = declarative_base() +class DocumentModel(Base): + __tablename__ = "documents" + __table_args__ = {"extend_existing": True} + + id = Column(String, primary_key=True, nullable=False) + user_id = Column(String, nullable=False) + # TODO: Investigate why this breaks during table creation due to FK + # source_id = Column(String, ForeignKey("sources.id"), nullable=False) + source_id = Column(String, nullable=False) + text = Column(Text, nullable=False) # The text of the document + metadata_ = Column(JSON, nullable=True) # Any additional metadata + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + def __repr__(self): + return f"" + + def to_record(self): + return Document( + id=self.id, + user_id=self.user_id, + source_id=self.source_id, + text=self.text, + metadata_=self.metadata_, + created_at=self.created_at, + ) + + class LLMConfigColumn(TypeDecorator): """Custom type for storing LLMConfig as JSON""" @@ -866,6 +895,12 @@ def create_job(self, job: Job): session.add(JobModel(**vars(job))) session.commit() + @enforce_types + def list_documents_from_source(self, source_id: str): + with self.session_maker() as session: + results = session.query(DocumentModel).filter(DocumentModel.source_id == source_id).all() + return [r.to_record() for r in results] + def delete_job(self, job_id: str): with self.session_maker() as session: session.query(JobModel).filter(JobModel.id == job_id).delete() diff --git a/letta/schemas/document.py b/letta/schemas/document.py index 2628ac5702..a262a35903 100644 --- a/letta/schemas/document.py +++ b/letta/schemas/document.py @@ -1,8 +1,10 @@ +from datetime import datetime from typing import Dict, Optional from pydantic import Field from letta.schemas.letta_base import LettaBase +from letta.utils import get_utc_time class DocumentBase(LettaBase): @@ -15,7 +17,11 @@ class Document(DocumentBase): """Representation of a single document (broken up into `Passage` objects)""" id: str = DocumentBase.generate_id_field() - text: str = Field(..., description="The text of the document.") - source_id: str = Field(..., description="The unique identifier of the source associated with the document.") user_id: str = Field(description="The unique identifier of the user associated with the document.") + source_id: str = Field(..., description="The unique identifier of the source associated with the 
document.") + text: str = Field(..., description="The text of the document.") metadata_: Optional[Dict] = Field({}, description="The metadata of the document.") + created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.") + + class Config: + extra = "allow" diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index a59abd31bb..6426b34940 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -186,18 +186,16 @@ def list_passages( return passages -@router.get("/{source_id}/documents", response_model=List[Document], operation_id="list_source_documents") -def list_documents( +@router.get("/{source_id}/documents", response_model=List[Document], operation_id="list_documents_from_source") +def list_documents_from_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all documents associated with a data source. """ - actor = server.get_user_or_default(user_id=user_id) - - documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id) + # return [] + documents = server.list_documents_from_source(source_id=source_id) return documents diff --git a/letta/server/server.py b/letta/server/server.py index d9036f26a0..991c24d5fd 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -151,7 +151,7 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N from sqlalchemy import create_engine from sqlalchemy.orm import declarative_base, sessionmaker -from letta.agent_store.db import MessageModel, PassageModel +from letta.agent_store.db import DocumentModel, MessageModel, PassageModel from letta.config import LettaConfig # NOTE: hack to see if single session management works @@ -197,6 +197,7 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N JobModel.__table__, PassageModel.__table__, MessageModel.__table__, + DocumentModel.__table__, OrganizationModel.__table__, ], ) @@ -1573,8 +1574,7 @@ def load_data( # get the data connectors passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) - # TODO: add document store support - document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) + document_store = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) # load data into the document store passage_count, document_count = load_data(connector, source, passage_store, document_store) @@ -1629,14 +1629,14 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: # list all attached sources to an agent return self.ms.list_attached_sources(agent_id) + def list_documents_from_source(self, source_id: str) -> List[Document]: + # list all attached sources to an agent + return self.ms.list_documents_from_source(source_id=source_id) + def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] - def list_data_source_documents(self, user_id: str, source_id: str) -> List[Document]: - warnings.warn("list_data_source_documents is not yet implemented, returning empty list.", category=UserWarning) - return [] - def list_all_sources(self, user_id: 
str) -> List[Source]: """List all sources (w/ extra metadata) belonging to a user""" diff --git a/tests/test_client.py b/tests/test_client.py index aac72e94db..7f56b1a4dc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -298,6 +298,54 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState): # print("CONFIG", config_response) +def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + # clear sources + for source in client.list_sources(): + client.delete_source(source.id) + + # clear jobs + for job in client.list_jobs(): + client.delete_job(job.id) + + # create a source + source = client.create_source(name="test_source") + + # load a file into a source (non-blocking job) + filename = "tests/data/memgpt_paper.pdf" + upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False) + print("Upload job", upload_job, upload_job.status, upload_job.metadata_) + + # view active jobs + active_jobs = client.list_active_jobs() + jobs = client.list_jobs() + assert upload_job.id in [j.id for j in jobs] + assert len(active_jobs) == 1 + assert active_jobs[0].metadata_["source_id"] == source.id + + # wait for job to finish (with timeout) + timeout = 120 + start_time = time.time() + while True: + status = client.get_job(upload_job.id).status + print(f"\r{status}", end="", flush=True) + if status == JobStatus.completed: + break + time.sleep(1) + if time.time() - start_time > timeout: + raise ValueError("Job did not finish in time") + + # Get the documents + documents = client.list_documents_from_source(source.id) + assert len(documents) == 13 # 13 pages + + # Get the memgpt paper + document = documents[0] + assert document.metadata_.get("file_name", None) == "memgpt_paper.pdf" + assert document.source_id == source.id + + def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() From 7303e86a71d16e77703cebbd72022698aefa5929 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 10 Oct 2024 13:46:56 -0700 Subject: [PATCH 06/23] Merge main --- letta/agent_store/db.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index dbaaaa3a5f..078909d38b 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -517,18 +517,36 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b)) def insert_many(self, records, exists_ok=True, show_progress=False): - pass - # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel) if len(records) == 0: return + + added_ids = [] # avoid adding duplicates + # NOTE: this has not great performance due to the excessive commits with self.session_maker() as session: iterable = tqdm(records) if show_progress else records for record in iterable: # db_record = self.db_model(**vars(record)) - db_record = self.db_model(**record.dict()) - session.add(db_record) - session.commit() + + if record.id in added_ids: + continue + + existing_record = session.query(self.db_model).filter_by(id=record.id).first() + if existing_record: + if exists_ok: + fields = record.model_dump() + fields.pop("id") + session.query(self.db_model).filter(self.db_model.id == record.id).update(fields) + session.commit() + else: + raise ValueError(f"Record with id {record.id} already exists.") + + else: + db_record = 
self.db_model(**record.dict()) + session.add(db_record) + session.commit() + + added_ids.append(record.id) def insert(self, record, exists_ok=True): self.insert_many([record], exists_ok=exists_ok) From aec019d9e4ba16c9fd052a518b9f6b51a611c0eb Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 10 Oct 2024 13:47:31 -0700 Subject: [PATCH 07/23] Merge main --- letta/agent_store/db.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 078909d38b..d84a8ebdda 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -403,8 +403,6 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Op return records def insert_many(self, records, exists_ok=True, show_progress=False): - pass - # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel) if len(records) == 0: return From 99d46a414ddd57f632b61a6b9542b5adea883d51 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 10 Oct 2024 15:04:53 -0700 Subject: [PATCH 08/23] Refactor documents to files --- docs/data_sources.md | 4 +- docs/generate_docs.py | 2 +- examples/notebooks/data_connector.ipynb | 2 +- letta/__init__.py | 2 +- letta/agent_store/db.py | 10 +-- letta/agent_store/storage.py | 20 ++--- letta/cli/cli_load.py | 2 +- letta/client/client.py | 24 +++--- letta/data_sources/connectors.py | 85 ++++++++++++--------- letta/metadata.py | 17 ++--- letta/schemas/{document.py => file.py} | 7 +- letta/schemas/job.py | 2 +- letta/schemas/source.py | 4 +- letta/server/rest_api/routers/v1/sources.py | 12 +-- letta/server/server.py | 20 ++--- tests/test_client.py | 4 +- tests/utils.py | 9 +-- 17 files changed, 115 insertions(+), 111 deletions(-) rename letta/schemas/{document.py => file.py} (81%) diff --git a/docs/data_sources.md b/docs/data_sources.md index a0d74b31f8..a0567207c6 100644 --- a/docs/data_sources.md +++ b/docs/data_sources.md @@ -102,7 +102,7 @@ class DummyDataConnector(DataConnector): for text in self.texts: yield text, {"metadata": "dummy"} - def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: + def generate_passages(self, document_text: str, documents: File, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: for doc in documents: - yield doc.text, doc.metadata + yield document_text, doc.metadata ``` diff --git a/docs/generate_docs.py b/docs/generate_docs.py index c315ad4ce5..8db3bfb20e 100644 --- a/docs/generate_docs.py +++ b/docs/generate_docs.py @@ -70,7 +70,7 @@ def generate_modules(config): "Message", "Passage", "AgentState", - "Document", + "File", "Source", "LLMConfig", "EmbeddingConfig", diff --git a/examples/notebooks/data_connector.ipynb b/examples/notebooks/data_connector.ipynb index bd26987a66..d16033b7e5 100644 --- a/examples/notebooks/data_connector.ipynb +++ b/examples/notebooks/data_connector.ipynb @@ -270,7 +270,7 @@ "outputs": [], "source": [ "from letta.data_sources.connectors import DataConnector \n", - "from letta.schemas.document import Document\n", + "from letta.schemas.document import File\n", "from llama_index.core import Document as LlamaIndexDocument\n", "from llama_index.core import SummaryIndex\n", "from llama_index.readers.web import SimpleWebPageReader\n", diff --git a/letta/__init__.py b/letta/__init__.py index bc2004172b..87d434ffea 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -7,9 +7,9 @@ # imports for easier access from letta.schemas.agent import AgentState from letta.schemas.block import Block -from 
letta.schemas.document import Document from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus +from letta.schemas.file import File from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage from letta.schemas.llm_config import LLMConfig diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index d84a8ebdda..b1af67a936 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -27,7 +27,7 @@ from letta.agent_store.storage import StorageConnector, TableType from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM -from letta.metadata import DocumentModel, EmbeddingConfigColumn, ToolCallColumn +from letta.metadata import EmbeddingConfigColumn, FileModel, ToolCallColumn # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.message import Message @@ -371,9 +371,9 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) self.db_model = MessageModel if self.config.recall_storage_uri is None: raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") - elif table_type == TableType.DOCUMENTS: + elif table_type == TableType.FILES: self.uri = self.config.documents_storage_uri - self.db_model = DocumentModel + self.db_model = FileModel if self.config.documents_storage_uri is None: raise ValueError(f"Must specify documents_storage_uri in config {self.config.config_path}") else: @@ -494,11 +494,11 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) if self.path is None: raise ValueError(f"Must specify recall_storage_path in config.") self.db_model = MessageModel - elif table_type == TableType.DOCUMENTS: + elif table_type == TableType.FILES: self.path = self.config.documents_storage_path if self.path is None: raise ValueError(f"Must specify documents_storage_path in config.") - self.db_model = DocumentModel + self.db_model = FileModel else: raise ValueError(f"Table type {table_type} not implemented") diff --git a/letta/agent_store/storage.py b/letta/agent_store/storage.py index 785db6b613..a2092e1c70 100644 --- a/letta/agent_store/storage.py +++ b/letta/agent_store/storage.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from letta.config import LettaConfig -from letta.schemas.document import Document +from letta.schemas.file import File from letta.schemas.message import Message from letta.schemas.passage import Passage from letta.utils import printd @@ -22,7 +22,7 @@ class TableType: ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id} RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id} PASSAGES = "passages" # TODO - DOCUMENTS = "documents" + FILES = "files" # table names used by Letta @@ -33,17 +33,17 @@ class TableType: # external data source tables PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source) -DOCUMENT_TABLE_NAME = "letta_documents" # original documents (from source) +DOCUMENT_TABLE_NAME = "letta_documents" # original files (from source) class StorageConnector: - """Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory""" + """Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory""" type: Type[BaseModel] def __init__( self, - table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS], + table_type: 
Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES], config: LettaConfig, user_id, agent_id=None, @@ -59,8 +59,8 @@ def __init__( elif table_type == TableType.RECALL_MEMORY: self.type = Message self.table_name = RECALL_TABLE_NAME - elif table_type == TableType.DOCUMENTS: - self.type = Document + elif table_type == TableType.FILES: + self.type = File self.table_name = DOCUMENT_TABLE_NAME elif table_type == TableType.PASSAGES: self.type = Passage @@ -74,7 +74,7 @@ def __init__( # agent-specific table assert agent_id is not None, "Agent ID must be provided for agent-specific tables" self.filters = {"user_id": self.user_id, "agent_id": self.agent_id} - elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS: + elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES: # setup base filters for user-specific tables assert agent_id is None, "Agent ID must not be provided for user-specific tables" self.filters = {"user_id": self.user_id} @@ -83,7 +83,7 @@ def __init__( @staticmethod def get_storage_connector( - table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS], + table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES], config: LettaConfig, user_id, agent_id=None, @@ -92,7 +92,7 @@ def get_storage_connector( storage_type = config.archival_storage_type elif table_type == TableType.RECALL_MEMORY: storage_type = config.recall_storage_type - elif table_type == TableType.DOCUMENTS: + elif table_type == TableType.FILES: storage_type = config.documents_storage_type else: raise ValueError(f"Table type {table_type} not implemented") diff --git a/letta/cli/cli_load.py b/letta/cli/cli_load.py index eaef8d697e..61518bc038 100644 --- a/letta/cli/cli_load.py +++ b/letta/cli/cli_load.py @@ -106,7 +106,7 @@ def load_vector_database( # document_store=None, # passage_store=passage_storage, # ) - # print(f"Loaded {num_passages} passages and {num_documents} documents from {name}") + # print(f"Loaded {num_passages} passages and {num_documents} files from {name}") # except Exception as e: # typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) # ms.delete_source(source_id=source.id) diff --git a/letta/client/client.py b/letta/client/client.py index f9cbaccce9..9d83a60060 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -21,11 +21,11 @@ UpdateHuman, UpdatePersona, ) -from letta.schemas.document import Document from letta.schemas.embedding_config import EmbeddingConfig # new schemas from letta.schemas.enums import JobStatus, MessageRole +from letta.schemas.file import File from letta.schemas.job import Job from letta.schemas.letta_request import LettaRequest from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse @@ -233,7 +233,7 @@ def list_sources(self) -> List[Source]: def list_attached_sources(self, agent_id: str) -> List[Source]: raise NotImplementedError - def list_documents_from_source(self, source_id: str) -> List[Document]: + def list_files_from_source(self, source_id: str) -> List[File]: raise NotImplementedError def update_source(self, source_id: str, name: Optional[str] = None) -> Source: @@ -1098,20 +1098,20 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: raise ValueError(f"Failed to list attached sources: {response.text}") return [Source(**source) for source in response.json()] - def 
list_documents_from_source(self, source_id: str) -> List[Document]: + def list_files_from_source(self, source_id: str) -> List[File]: """ - List documents from source. + List files from source. Args: source_id (str): ID of the source Returns: - documents (List[Document]): List of documents + files (List[File]): List of files """ - response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/documents", headers=self.headers) + response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files", headers=self.headers) if response.status_code != 200: - raise ValueError(f"Failed to list documents with source id {source_id}: [{response.status_code}] {response.text}") - return [Document(**document) for document in response.json()] + raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}") + return [File(**file) for file in response.json()] def update_source(self, source_id: str, name: Optional[str] = None) -> Source: """ @@ -2289,17 +2289,17 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: """ return self.server.list_attached_sources(agent_id=agent_id) - def list_documents_from_source(self, source_id: str) -> List[Document]: + def list_files_from_source(self, source_id: str) -> List[File]: """ - List documents from source. + List files from source. Args: source_id (str): ID of the source Returns: - documents (List[Document]): List of documents + files (List[File]): List of files """ - return self.server.list_documents_from_source(source_id=source_id) + return self.server.list_files_from_source(source_id=source_id) def update_source(self, source_id: str, name: Optional[str] = None) -> Source: """ diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 731850528b..a4ab1cb5be 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -2,10 +2,11 @@ import typer from llama_index.core import Document as LlamaIndexDocument +from llama_index.readers.file import PDFReader from letta.agent_store.storage import StorageConnector from letta.embeddings import embedding_model -from letta.schemas.document import Document +from letta.schemas.file import File from letta.schemas.passage import Passage from letta.schemas.source import Source from letta.utils import create_uuid_from_string @@ -13,23 +14,24 @@ class DataConnector: """ - Base class for data connectors that can be extended to generate documents and passages from a custom data source. + Base class for data connectors that can be extended to generate files and passages from a custom data source. """ - def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: + def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[File]: """ Generate document text and metadata from a data source. Returns: - documents (Iterator[Tuple[str, Dict]]): Generate a tuple of string text and metadata dictionary for each document. + files (Iterator[Tuple[str, Dict]]): Generate a tuple of string text and metadata dictionary for each document. """ - def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: + def generate_passages(self, file_text: str, file: File, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: """ - Generate passage text and metadata from a list of documents. + Generate passage text and metadata from a list of files. 
Args: - documents (List[Document]): List of documents to generate passages from. + file_text (str): The text of the document + file (File): The document to generate passages from. chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024. Returns: @@ -43,31 +45,32 @@ def load_data( passage_store: StorageConnector, document_store: StorageConnector, ): - """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" + """Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id.""" embedding_config = source.embedding_config # embedding model embed_model = embedding_model(embedding_config) - # insert passages/documents + # insert passages/file passages = [] embedding_to_document_name = {} passage_count = 0 document_count = 0 - for document_text, document_metadata in connector.generate_documents(): + for file_text, document_metadata in connector.generate_files(): # insert document into storage - document = Document( - id=create_uuid_from_string(f"{str(source.id)}_{document_text}"), + file = File( + id=create_uuid_from_string(f"{str(source.id)}_{file_text}"), user_id=source.user_id, source_id=source.id, - text=document_text, metadata_=document_metadata, ) document_count += 1 - document_store.insert(document) + document_store.insert(file) # generate passages - for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size): + for passage_text, passage_metadata in connector.generate_passages( + file_text, file, chunk_size=embedding_config.embedding_chunk_size + ): # for some reason, llama index parsers sometimes return empty strings if len(passage_text) == 0: typer.secho( @@ -89,7 +92,7 @@ def load_data( passage = Passage( id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"), text=passage_text, - doc_id=document.id, + doc_id=file.id, source_id=source.id, metadata_=passage_metadata, user_id=source.user_id, @@ -98,16 +101,16 @@ def load_data( ) hashable_embedding = tuple(passage.embedding) - document_name = document.metadata_.get("file_path", document.id) + file_name = file.metadata_.get("file_path", file.id) if hashable_embedding in embedding_to_document_name: typer.secho( - f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.", + f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.", fg=typer.colors.YELLOW, ) continue passages.append(passage) - embedding_to_document_name[hashable_embedding] = document_name + embedding_to_document_name[hashable_embedding] = file_name if len(passages) >= 100: # insert passages into passage store passage_store.insert_many(passages) @@ -143,38 +146,44 @@ def __init__(self, input_files: List[str] = None, input_directory: str = None, r if self.recursive == True: assert self.input_directory is not None, "Must provide input directory if recursive is True." 
- def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: + def generate_files(self) -> Iterator[Tuple[str, Dict]]: from llama_index.core import SimpleDirectoryReader + # We need to hijack the file extractor here + # The default behavior is to split up the PDF by page, but we want to return the full document + file_extractor = { + ".pdf": PDFReader(return_full_document=True), + } + if self.input_directory is not None: reader = SimpleDirectoryReader( input_dir=self.input_directory, recursive=self.recursive, required_exts=[ext.strip() for ext in str(self.extensions).split(",")], + exclude=["*png", "*jpg", "*jpeg"], # Don't support images for now + file_extractor=file_extractor, ) else: assert self.input_files is not None, "Must provide input files if input_dir is None" - reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files]) + reader = SimpleDirectoryReader( + input_files=[str(f) for f in self.input_files], + exclude=["*png", "*jpg", "*jpeg"], # Don't support images for now + file_extractor=file_extractor, + ) llama_index_docs = reader.load_data(show_progress=True) for llama_index_doc in llama_index_docs: yield llama_index_doc.text, llama_index_doc.metadata - def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: + def generate_passages(self, file_text: str, file: File, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: # use llama index to run embeddings code - # from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import TokenTextSplitter parser = TokenTextSplitter(chunk_size=chunk_size) - for document in documents: - llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)] - nodes = parser.get_nodes_from_documents(llama_index_docs) - for node in nodes: - # passage = Passage( - # text=node.text, - # doc_id=document.id, - # ) - yield node.text, None + llama_index_docs = [LlamaIndexDocument(text=file_text, metadata=file.metadata_)] + nodes = parser.get_nodes_from_documents(llama_index_docs) + for node in nodes: + yield node.text, None class WebConnector(DirectoryConnector): @@ -182,17 +191,17 @@ def __init__(self, urls: List[str] = None, html_to_text: bool = True): self.urls = urls self.html_to_text = html_to_text - def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: + def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: from llama_index.readers.web import SimpleWebPageReader - documents = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls) - for document in documents: + files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls) + for document in files: yield document.text, {"url": document.id_} class VectorDBConnector(DataConnector): # NOTE: this class has not been properly tested, so is unlikely to work - # TODO: allow loading multiple tables (1:1 mapping between Document and Table) + # TODO: allow loading multiple tables (1:1 mapping between File and Table) def __init__( self, @@ -215,10 +224,10 @@ def __init__( self.engine = create_engine(uri) - def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: + def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: yield self.table_name, None - def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> 
Iterator[Passage]: + def generate_passages(self, file_text: str, file: File, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: from pgvector.sqlalchemy import Vector from sqlalchemy import Inspector, MetaData, Table, select diff --git a/letta/metadata.py b/letta/metadata.py index 301fe2c51f..10adf7cd6d 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -12,7 +12,6 @@ DateTime, Index, String, - Text, TypeDecorator, desc, func, @@ -24,9 +23,9 @@ from letta.schemas.agent import AgentState from letta.schemas.api_key import APIKey from letta.schemas.block import Block, Human, Persona -from letta.schemas.document import Document from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus +from letta.schemas.file import File from letta.schemas.job import Job from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory @@ -41,8 +40,8 @@ Base = declarative_base() -class DocumentModel(Base): - __tablename__ = "documents" +class FileModel(Base): + __tablename__ = "files" __table_args__ = {"extend_existing": True} id = Column(String, primary_key=True, nullable=False) @@ -50,19 +49,17 @@ class DocumentModel(Base): # TODO: Investigate why this breaks during table creation due to FK # source_id = Column(String, ForeignKey("sources.id"), nullable=False) source_id = Column(String, nullable=False) - text = Column(Text, nullable=False) # The text of the document metadata_ = Column(JSON, nullable=True) # Any additional metadata created_at = Column(DateTime(timezone=True), server_default=func.now()) def __repr__(self): - return f"" + return f"" def to_record(self): - return Document( + return File( id=self.id, user_id=self.user_id, source_id=self.source_id, - text=self.text, metadata_=self.metadata_, created_at=self.created_at, ) @@ -896,9 +893,9 @@ def create_job(self, job: Job): session.commit() @enforce_types - def list_documents_from_source(self, source_id: str): + def list_files_from_source(self, source_id: str): with self.session_maker() as session: - results = session.query(DocumentModel).filter(DocumentModel.source_id == source_id).all() + results = session.query(FileModel).filter(FileModel.source_id == source_id).all() return [r.to_record() for r in results] def delete_job(self, job_id: str): diff --git a/letta/schemas/document.py b/letta/schemas/file.py similarity index 81% rename from letta/schemas/document.py rename to letta/schemas/file.py index a262a35903..c717b8ad20 100644 --- a/letta/schemas/document.py +++ b/letta/schemas/file.py @@ -7,19 +7,18 @@ from letta.utils import get_utc_time -class DocumentBase(LettaBase): +class FileBase(LettaBase): """Base class for document schemas""" __id_prefix__ = "doc" -class Document(DocumentBase): +class File(FileBase): """Representation of a single document (broken up into `Passage` objects)""" - id: str = DocumentBase.generate_id_field() + id: str = FileBase.generate_id_field() user_id: str = Field(description="The unique identifier of the user associated with the document.") source_id: str = Field(..., description="The unique identifier of the source associated with the document.") - text: str = Field(..., description="The text of the document.") metadata_: Optional[Dict] = Field({}, description="The metadata of the document.") created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.") diff --git a/letta/schemas/job.py b/letta/schemas/job.py index da83d4be3d..4499c167cd 100644 --- a/letta/schemas/job.py +++ 
b/letta/schemas/job.py @@ -15,7 +15,7 @@ class JobBase(LettaBase): class Job(JobBase): """ - Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding documents). + Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding files). Parameters: id (str): The unique identifier of the job. diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 827ebb9f51..8f816ad701 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -28,7 +28,7 @@ class SourceCreate(BaseSource): class Source(BaseSource): """ - Representation of a source, which is a collection of documents and passages. + Representation of a source, which is a collection of files and passages. Parameters: id (str): The ID of the source @@ -59,4 +59,4 @@ class UploadFileToSourceRequest(BaseModel): class UploadFileToSourceResponse(BaseModel): source: Source = Field(..., description="The source the file was uploaded to.") added_passages: int = Field(..., description="The number of passages added to the source.") - added_documents: int = Field(..., description="The number of documents added to the source.") + added_documents: int = Field(..., description="The number of files added to the source.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 6426b34940..ecaae56f9a 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile -from letta.schemas.document import Document +from letta.schemas.file import File from letta.schemas.job import Job from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate @@ -186,17 +186,17 @@ def list_passages( return passages -@router.get("/{source_id}/documents", response_model=List[Document], operation_id="list_documents_from_source") -def list_documents_from_source( +@router.get("/{source_id}/files", response_model=List[File], operation_id="list_files_from_source") +def list_files_from_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), ): """ - List all documents associated with a data source. + List all files associated with a data source. 
""" # return [] - documents = server.list_documents_from_source(source_id=source_id) - return documents + files = server.list_files_from_source(source_id=source_id) + return files def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes): diff --git a/letta/server/server.py b/letta/server/server.py index 72c155d3d3..82bba1437b 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -61,11 +61,11 @@ CreatePersona, UpdateBlock, ) -from letta.schemas.document import Document from letta.schemas.embedding_config import EmbeddingConfig # openai schemas from letta.schemas.enums import JobStatus +from letta.schemas.file import File from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage from letta.schemas.llm_config import LLMConfig @@ -151,7 +151,7 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N from sqlalchemy import create_engine from sqlalchemy.orm import declarative_base, sessionmaker -from letta.agent_store.db import DocumentModel, MessageModel, PassageModel +from letta.agent_store.db import FileModel, MessageModel, PassageModel from letta.config import LettaConfig # NOTE: hack to see if single session management works @@ -197,7 +197,7 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N JobModel.__table__, PassageModel.__table__, MessageModel.__table__, - DocumentModel.__table__, + FileModel.__table__, OrganizationModel.__table__, ], ) @@ -1545,7 +1545,7 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Jo # job.status = JobStatus.failed # job.metadata_["error"] = error # self.ms.update_job(job) - # # TODO: delete any associated passages/documents? + # # TODO: delete any associated passages/files? 
# # return failed job # return job @@ -1574,7 +1574,7 @@ def load_data( # get the data connectors passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) - document_store = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) + document_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id) # load data into the document store passage_count, document_count = load_data(connector, source, passage_store, document_store) @@ -1634,9 +1634,9 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: # list all attached sources to an agent return self.ms.list_attached_sources(agent_id) - def list_documents_from_source(self, source_id: str) -> List[Document]: + def list_files_from_source(self, source_id: str) -> List[File]: # list all attached sources to an agent - return self.ms.list_documents_from_source(source_id=source_id) + return self.ms.list_files_from_source(source_id=source_id) def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) @@ -1655,9 +1655,9 @@ def list_all_sources(self, user_id: str) -> List[Source]: passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) num_passages = passage_conn.size({"source_id": source.id}) - # TODO: add when documents table implemented - ## count number of documents - # document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) + # TODO: add when files table implemented + ## count number of files + # document_conn = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id) # num_documents = document_conn.size({"data_source": source.name}) num_documents = 0 diff --git a/tests/test_client.py b/tests/test_client.py index bd3079fa60..9ad17b0df4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -337,8 +337,8 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): raise ValueError("Job did not finish in time") # Get the documents - documents = client.list_documents_from_source(source.id) - assert len(documents) == 13 # 13 pages + documents = client.list_files_from_source(source.id) + assert len(documents) == 1 # Should be condensed to one document # Get the memgpt paper document = documents[0] diff --git a/tests/utils.py b/tests/utils.py index f1c6cbfc02..8323c68831 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector -from letta.schemas.document import Document +from letta.schemas.file import File from letta.settings import TestSettings from .constants import TIMEOUT @@ -19,13 +19,12 @@ class DummyDataConnector(DataConnector): def __init__(self, texts: List[str]): self.texts = texts - def generate_documents(self) -> Iterator[Tuple[str, Dict]]: + def generate_files(self) -> Iterator[Tuple[str, Dict]]: for text in self.texts: yield text, {"metadata": "dummy"} - def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: - for doc in documents: - yield doc.text, doc.metadata_ + def generate_passages(self, file_text: str, file: File, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: + yield file_text, file.metadata_ def wipe_config(): From 
08862a013dc8298d425fbda88955c520cf8ade0d Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 10 Oct 2024 15:10:27 -0700 Subject: [PATCH 09/23] Respond to comments --- letta/agent_store/db.py | 10 +++++----- letta/agent_store/storage.py | 2 +- letta/config.py | 5 ----- letta/data_sources/connectors.py | 1 - 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index b1af67a936..5efa080975 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -372,10 +372,10 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) if self.config.recall_storage_uri is None: raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") elif table_type == TableType.FILES: - self.uri = self.config.documents_storage_uri + self.uri = self.config.metadata_storage_uri self.db_model = FileModel - if self.config.documents_storage_uri is None: - raise ValueError(f"Must specify documents_storage_uri in config {self.config.config_path}") + if self.config.metadata_storage_uri is None: + raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}") else: raise ValueError(f"Table type {table_type} not implemented") @@ -495,9 +495,9 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) raise ValueError(f"Must specify recall_storage_path in config.") self.db_model = MessageModel elif table_type == TableType.FILES: - self.path = self.config.documents_storage_path + self.path = self.config.metadata_storage_path if self.path is None: - raise ValueError(f"Must specify documents_storage_path in config.") + raise ValueError(f"Must specify metadata_storage_path in config.") self.db_model = FileModel else: diff --git a/letta/agent_store/storage.py b/letta/agent_store/storage.py index a2092e1c70..9f953dc3d9 100644 --- a/letta/agent_store/storage.py +++ b/letta/agent_store/storage.py @@ -93,7 +93,7 @@ def get_storage_connector( elif table_type == TableType.RECALL_MEMORY: storage_type = config.recall_storage_type elif table_type == TableType.FILES: - storage_type = config.documents_storage_type + storage_type = config.metadata_storage_type else: raise ValueError(f"Table type {table_type} not implemented") diff --git a/letta/config.py b/letta/config.py index e03bac9bfd..b07c7f4cbc 100644 --- a/letta/config.py +++ b/letta/config.py @@ -78,11 +78,6 @@ class LettaConfig: metadata_storage_path: str = LETTA_DIR metadata_storage_uri: str = None - # database configs: document storage - documents_storage_type: str = "sqlite" - documents_storage_path: str = LETTA_DIR - documents_storage_uri: str = None - # database configs: agent state persistence_manager_type: str = None # in-memory, db persistence_manager_save_file: str = None # local file diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index a4ab1cb5be..4dddbd6376 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -59,7 +59,6 @@ def load_data( for file_text, document_metadata in connector.generate_files(): # insert document into storage file = File( - id=create_uuid_from_string(f"{str(source.id)}_{file_text}"), user_id=source.user_id, source_id=source.id, metadata_=document_metadata, From 427869b1883570acf5b3b6ffc433cbae61fdb8fb Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 10 Oct 2024 16:15:53 -0700 Subject: [PATCH 10/23] Add pagination --- letta/agent_store/db.py | 6 ++--- letta/agent_store/lancedb.py | 4 +-- 
letta/agent_store/milvus.py | 2 +- letta/agent_store/qdrant.py | 2 +- letta/client/client.py | 29 ++++++++++++++------- letta/data_sources/connectors.py | 2 +- letta/metadata.py | 26 +++++++++++++++--- letta/schemas/file.py | 9 +++++-- letta/schemas/passage.py | 6 ++--- letta/server/rest_api/routers/v1/sources.py | 12 ++++----- letta/server/server.py | 6 ++--- tests/test_client.py | 9 ++++--- 12 files changed, 74 insertions(+), 39 deletions(-) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 5efa080975..a2b7433cce 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -141,7 +141,7 @@ class PassageModel(Base): id = Column(String, primary_key=True) user_id = Column(String, nullable=False) text = Column(String) - doc_id = Column(String) + file_id = Column(String) agent_id = Column(String) source_id = Column(String) @@ -160,7 +160,7 @@ class PassageModel(Base): # Add a datetime column, with default value as the current time created_at = Column(DateTime(timezone=True)) - Index("passage_idx_user", user_id, agent_id, doc_id), + Index("passage_idx_user", user_id, agent_id, file_id), def __repr__(self): return f" Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: - # use llama index to run embeddings code + for metadata in extract_metadata_from_files(files): + yield FileMetadata( + user_id=source.user_id, + source_id=source.id, + file_name=metadata.get("file_name"), + file_path=metadata.get("file_path"), + file_type=metadata.get("file_type"), + file_size=metadata.get("file_size"), + file_creation_date=metadata.get("file_creation_date"), + file_last_modified_date=metadata.get("file_last_modified_date"), + ) + + def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: from llama_index.core.node_parser import TokenTextSplitter + from llama_index.readers.file import FlatReader + + # TODO: Extend to different kinds of files + # This one just reads raw text + docs = FlatReader().load_data(Path(file.file_path)) parser = TokenTextSplitter(chunk_size=chunk_size) - llama_index_docs = [LlamaIndexDocument(text=file_text, metadata=file.metadata_)] - nodes = parser.get_nodes_from_documents(llama_index_docs) + nodes = parser.get_nodes_from_documents(docs) for node in nodes: yield node.text, None -class WebConnector(DirectoryConnector): - def __init__(self, urls: List[str] = None, html_to_text: bool = True): - self.urls = urls - self.html_to_text = html_to_text - - def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: - from llama_index.readers.web import SimpleWebPageReader - - files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls) - for document in files: - yield document.text, {"url": document.id_} - - -class VectorDBConnector(DataConnector): - # NOTE: this class has not been properly tested, so is unlikely to work - # TODO: allow loading multiple tables (1:1 mapping between File and Table) - - def __init__( - self, - name: str, - uri: str, - table_name: str, - text_column: str, - embedding_column: str, - embedding_dim: int, - ): - self.name = name - self.uri = uri - self.table_name = table_name - self.text_column = text_column - self.embedding_column = embedding_column - self.embedding_dim = embedding_dim - - # connect to db table - from sqlalchemy import create_engine - - self.engine = create_engine(uri) - - def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: - yield self.table_name, None - - def generate_passages(self, file_text: str, file: File, 
chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: - from pgvector.sqlalchemy import Vector - from sqlalchemy import Inspector, MetaData, Table, select - - metadata = MetaData() - # Create an inspector to inspect the database - inspector = Inspector.from_engine(self.engine) - table_names = inspector.get_table_names() - assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}." - - table = Table(self.table_name, metadata, autoload_with=self.engine) - - # Prepare a select statement - select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim))) - - # Execute the query and fetch the results - # TODO: paginate results - with self.engine.connect() as connection: - result = connection.execute(select_statement).fetchall() - - for text, embedding in result: - # assume that embeddings are the same model as in config - # TODO: don't re-compute embedding - yield text, {"embedding": embedding} +""" +The below isn't used anywhere, it isn't tested, and pretty much should be deleted. +- Matt +""" +# class WebConnector(DirectoryConnector): +# def __init__(self, urls: List[str] = None, html_to_text: bool = True): +# self.urls = urls +# self.html_to_text = html_to_text +# +# def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: +# from llama_index.readers.web import SimpleWebPageReader +# +# files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls) +# for document in files: +# yield document.text, {"url": document.id_} +# +# +# class VectorDBConnector(DataConnector): +# # NOTE: this class has not been properly tested, so is unlikely to work +# # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table) +# +# def __init__( +# self, +# name: str, +# uri: str, +# table_name: str, +# text_column: str, +# embedding_column: str, +# embedding_dim: int, +# ): +# self.name = name +# self.uri = uri +# self.table_name = table_name +# self.text_column = text_column +# self.embedding_column = embedding_column +# self.embedding_dim = embedding_dim +# +# # connect to db table +# from sqlalchemy import create_engine +# +# self.engine = create_engine(uri) +# +# def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]: +# yield self.table_name, None +# +# def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]: +# from pgvector.sqlalchemy import Vector +# from sqlalchemy import Inspector, MetaData, Table, select +# +# metadata = MetaData() +# # Create an inspector to inspect the database +# inspector = Inspector.from_engine(self.engine) +# table_names = inspector.get_table_names() +# assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}." 
+# +# table = Table(self.table_name, metadata, autoload_with=self.engine) +# +# # Prepare a select statement +# select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim))) +# +# # Execute the query and fetch the results +# # TODO: paginate results +# with self.engine.connect() as connection: +# result = connection.execute(select_statement).fetchall() +# +# for text, embedding in result: +# # assume that embeddings are the same model as in config +# # TODO: don't re-compute embedding +# yield text, {"embedding": embedding} diff --git a/letta/data_sources/connectors_helper.py b/letta/data_sources/connectors_helper.py new file mode 100644 index 0000000000..9d32e47255 --- /dev/null +++ b/letta/data_sources/connectors_helper.py @@ -0,0 +1,97 @@ +import mimetypes +import os +from datetime import datetime +from pathlib import Path +from typing import List, Optional + + +def extract_file_metadata(file_path) -> dict: + """Extracts metadata from a single file.""" + if not os.path.exists(file_path): + raise FileNotFoundError(file_path) + + file_metadata = { + "file_name": os.path.basename(file_path), + "file_path": file_path, + "file_type": mimetypes.guess_type(file_path)[0] or "unknown", + "file_size": os.path.getsize(file_path), + "file_creation_date": datetime.fromtimestamp(os.path.getctime(file_path)).strftime("%Y-%m-%d"), + "file_last_modified_date": datetime.fromtimestamp(os.path.getmtime(file_path)).strftime("%Y-%m-%d"), + } + return file_metadata + + +def extract_metadata_from_files(file_list): + """Extracts metadata for a list of files.""" + metadata = [] + for file_path in file_list: + file_metadata = extract_file_metadata(file_path) + if file_metadata: + metadata.append(file_metadata) + return metadata + + +def get_filenames_in_dir( + input_dir: str, recursive: bool = True, required_exts: Optional[List[str]] = None, exclude: Optional[List[str]] = None +): + """ + Recursively reads files from the directory, applying required_exts and exclude filters. + Ensures that required_exts and exclude do not overlap. + + Args: + input_dir (str): The directory to scan for files. + recursive (bool): Whether to scan directories recursively. + required_exts (list): List of file extensions to include (e.g., ['pdf', 'txt']). + If None or empty, matches any file extension. + exclude (list): List of file patterns to exclude (e.g., ['*png', '*jpg']). + + Returns: + list: A list of matching file paths. + """ + required_exts = required_exts or [] + exclude = exclude or [] + + # Ensure required_exts and exclude do not overlap + ext_set = set(required_exts) + exclude_set = set(exclude) + overlap = ext_set & exclude_set + if overlap: + raise ValueError(f"Extensions in required_exts and exclude overlap: {overlap}") + + def is_excluded(file_name): + """Check if a file matches any pattern in the exclude list.""" + for pattern in exclude: + if Path(file_name).match(pattern): + return True + return False + + files = [] + search_pattern = "**/*" if recursive else "*" + + for file_path in Path(input_dir).glob(search_pattern): + if file_path.is_file() and not is_excluded(file_path.name): + ext = file_path.suffix.lstrip(".") + # If required_exts is empty, match any file + if not required_exts or ext in required_exts: + files.append(file_path) + + return files + + +def assert_all_files_exist_locally(file_paths: List[str]) -> bool: + """ + Checks if all file paths in the provided list exist locally. 
+ Raises a FileNotFoundError with a list of missing files if any do not exist. + + Args: + file_paths (List[str]): List of file paths to check. + + Returns: + bool: True if all files exist, raises FileNotFoundError if any file is missing. + """ + missing_files = [file_path for file_path in file_paths if not Path(file_path).exists()] + + if missing_files: + raise FileNotFoundError(missing_files) + + return True diff --git a/letta/metadata.py b/letta/metadata.py index 7b0ead505f..87fa9ccf2b 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -11,6 +11,7 @@ Column, DateTime, Index, + Integer, String, TypeDecorator, desc, @@ -25,7 +26,7 @@ from letta.schemas.block import Block, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus -from letta.schemas.file import File, PaginatedListFilesResponse +from letta.schemas.file import FileMetadata, PaginatedListFilesResponse from letta.schemas.job import Job from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory @@ -49,18 +50,28 @@ class FileModel(Base): # TODO: Investigate why this breaks during table creation due to FK # source_id = Column(String, ForeignKey("sources.id"), nullable=False) source_id = Column(String, nullable=False) - metadata_ = Column(JSON, nullable=True) # Any additional metadata + file_name = Column(String, nullable=True) + file_path = Column(String, nullable=True) + file_type = Column(String, nullable=True) + file_size = Column(Integer, nullable=True) + file_creation_date = Column(String, nullable=True) + file_last_modified_date = Column(String, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) def __repr__(self): - return f"" + return f"" def to_record(self): - return File( + return FileMetadata( id=self.id, user_id=self.user_id, source_id=self.source_id, - metadata_=self.metadata_, + file_name=self.file_name, + file_path=self.file_path, + file_type=self.file_type, + file_size=self.file_size, + file_creation_date=self.file_creation_date, + file_last_modified_date=self.file_last_modified_date, created_at=self.created_at, ) @@ -908,7 +919,7 @@ def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[st # Limit the number of results returned results = query.limit(limit).all() - # Convert the results to the required File objects + # Convert the results to the required FileMetadata objects files = [r.to_record() for r in results] # Generate the next cursor from the last item in the current result set diff --git a/letta/schemas/file.py b/letta/schemas/file.py index 404727d567..141228f7a8 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, List, Optional +from typing import List, Optional from pydantic import BaseModel, Field @@ -13,19 +13,24 @@ class FileBase(LettaBase): __id_prefix__ = "doc" -class File(FileBase): - """Representation of a single document (broken up into `Passage` objects)""" +class FileMetadata(FileBase): + """Representation of a single FileMetadata (broken up into `Passage` objects)""" id: str = FileBase.generate_id_field() user_id: str = Field(description="The unique identifier of the user associated with the document.") source_id: str = Field(..., description="The unique identifier of the source associated with the document.") - metadata_: Optional[Dict] = Field({}, description="The metadata of the document.") - created_at: datetime = Field(default_factory=get_utc_time, 
description="The creation date of the passage.") + file_name: Optional[str] = Field(None, description="The name of the file.") + file_path: Optional[str] = Field(None, description="The path to the file.") + file_type: Optional[str] = Field(None, description="The type of the file (MIME type).") + file_size: Optional[int] = Field(None, description="The size of the file in bytes.") + file_creation_date: Optional[str] = Field(None, description="The creation date of the file.") + file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.") + created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.") class Config: extra = "allow" class PaginatedListFilesResponse(BaseModel): - files: List[File] + files: List[FileMetadata] next_cursor: Optional[str] = None # The cursor for fetching the next page, if any diff --git a/tests/utils.py b/tests/utils.py index 8323c68831..73c322f47f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector -from letta.schemas.file import File +from letta.schemas.file import FileMetadata from letta.settings import TestSettings from .constants import TIMEOUT @@ -23,7 +23,7 @@ def generate_files(self) -> Iterator[Tuple[str, Dict]]: for text in self.texts: yield text, {"metadata": "dummy"} - def generate_passages(self, file_text: str, file: File, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: + def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: yield file_text, file.metadata_ From bb04cea6ec7b8831bafb1eefaaac55cac0ac73dd Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 11:41:32 -0700 Subject: [PATCH 15/23] fix tests --- letta/data_sources/connectors.py | 10 +++------- tests/test_client.py | 6 +++--- tests/utils.py | 23 +++++++++++++++++------ 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index d4881e45ae..f729c8ad10 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Dict, Iterator, List, Tuple import typer @@ -167,15 +166,12 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]: ) def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: + from llama_index.core import SimpleDirectoryReader from llama_index.core.node_parser import TokenTextSplitter - from llama_index.readers.file import FlatReader - - # TODO: Extend to different kinds of files - # This one just reads raw text - docs = FlatReader().load_data(Path(file.file_path)) parser = TokenTextSplitter(chunk_size=chunk_size) - nodes = parser.get_nodes_from_documents(docs) + documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data() + nodes = parser.get_nodes_from_documents(documents) for node in nodes: yield node.text, None diff --git a/tests/test_client.py b/tests/test_client.py index c19b430c5c..7d707d670e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -359,9 +359,9 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): assert len(files) == 1 # Should be condensed to one document # Get the memgpt paper - document = files[0] - assert document.metadata_.get("file_name", None) == "memgpt_paper.pdf" - assert document.source_id == 
source.id + file = files[0] + assert file.file_name == "memgpt_paper.pdf" + assert file.source_id == source.id def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): diff --git a/tests/utils.py b/tests/utils.py index 73c322f47f..abffe841ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import datetime import os +from datetime import datetime from importlib import util from typing import Dict, Iterator, List, Tuple @@ -19,12 +20,22 @@ class DummyDataConnector(DataConnector): def __init__(self, texts: List[str]): self.texts = texts - def generate_files(self) -> Iterator[Tuple[str, Dict]]: - for text in self.texts: - yield text, {"metadata": "dummy"} - - def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: - yield file_text, file.metadata_ + def find_files(self, source) -> Iterator[FileMetadata]: + for _ in self.texts: + yield FileMetadata( + user_id="", + source_id="", + file_name="", + file_path="", + file_type="", + file_size=0, # Set to 0 as a placeholder + file_creation_date="1970-01-01", # Placeholder date + file_last_modified_date="1970-01-01", # Placeholder date + created_at=datetime.utcnow(), + ) + + def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: + yield "test", {} def wipe_config(): From 1d289009ee48618fafae31a15e7992a535b47ae3 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 13:42:27 -0700 Subject: [PATCH 16/23] Adapt dummy connector --- tests/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index abffe841ec..2168e2e387 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,10 +19,11 @@ class DummyDataConnector(DataConnector): def __init__(self, texts: List[str]): self.texts = texts + self.file_to_text = {} def find_files(self, source) -> Iterator[FileMetadata]: - for _ in self.texts: - yield FileMetadata( + for text in self.texts: + file_metadata = FileMetadata( user_id="", source_id="", file_name="", @@ -33,9 +34,12 @@ def find_files(self, source) -> Iterator[FileMetadata]: file_last_modified_date="1970-01-01", # Placeholder date created_at=datetime.utcnow(), ) + self.file_to_text[file_metadata.id] = text + + yield file_metadata def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: - yield "test", {} + yield self.file_to_text[file.id], {} def wipe_config(): From ecdccc63df2d3d3abf91de2fb868073fe623db51 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 13:43:37 -0700 Subject: [PATCH 17/23] rewind changes --- docs/data_sources.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/data_sources.md b/docs/data_sources.md index a0567207c6..a0d74b31f8 100644 --- a/docs/data_sources.md +++ b/docs/data_sources.md @@ -102,7 +102,7 @@ class DummyDataConnector(DataConnector): for text in self.texts: yield text, {"metadata": "dummy"} - def generate_passages(self, document_text: str, documents: File, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: + def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]: for doc in documents: - yield document_text, doc.metadata + yield doc.text, doc.metadata ``` From 03c8fa42017e74f1f959511ad13cd698dd32fa76 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 17:42:34 -0700 Subject: [PATCH 18/23] Adjust comments --- letta/schemas/file.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/letta/schemas/file.py b/letta/schemas/file.py index 3a0f8962f6..37e392545f 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -8,13 +8,13 @@ class FileMetadataBase(LettaBase): - """Base class for document schemas""" + """Base class for FileMetadata schemas""" __id_prefix__ = "file" class FileMetadata(FileMetadataBase): - """Representation of a single FileMetadata (broken up into `Passage` objects)""" + """Representation of a single FileMetadata""" id: str = FileMetadataBase.generate_id_field() user_id: str = Field(description="The unique identifier of the user associated with the document.") From 41601d00981b20f0773ae7097f34fb5e771ae1d0 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 17:50:08 -0700 Subject: [PATCH 19/23] fix-tests --- tests/test_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 9dfa44f039..3b606153f5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -354,8 +354,7 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): upload_file_using_client(client, source, filename) # Get the files - list_files_response = client.list_files_from_source(source.id) - files = list_files_response.files + files = client.list_files_from_source(source.id) assert len(files) == 1 # Should be condensed to one document # Get the memgpt paper From 6c5554054306467d4c826527062d4565a0351a62 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 18:08:48 -0700 Subject: [PATCH 20/23] run alembic --- ...7c830b_add_filemetadata_table_and_drop_.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py diff --git a/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py b/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py new file mode 100644 index 0000000000..7fb2b1ea7a --- /dev/null +++ b/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py @@ -0,0 +1,31 @@ +"""Add FileMetadata table and drop Documents table + +Revision ID: 7fb2327c830b +Revises: 9a505cc7eca9 +Create Date: 2024-10-11 18:07:21.765653 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "7fb2327c830b" +down_revision: Union[str, None] = "9a505cc7eca9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("passages", "embedding", existing_type=sa.NUMERIC(), type_=letta.agent_store.db.CommonVector(), existing_nullable=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column("passages", "embedding", existing_type=letta.agent_store.db.CommonVector(), type_=sa.NUMERIC(), existing_nullable=True) + # ### end Alembic commands ### From dc5ee65c244a2673aeb807eefbaac20463987eb2 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 11 Oct 2024 18:10:42 -0700 Subject: [PATCH 21/23] fix alembic autogen --- .../7fb2327c830b_add_filemetadata_table_and_drop_.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py b/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py index 7fb2b1ea7a..59ac6d752b 100644 --- a/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py +++ b/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py @@ -11,6 +11,7 @@ import sqlalchemy as sa from alembic import op +from letta.agent_store.db import CommonVector # revision identifiers, used by Alembic. revision: str = "7fb2327c830b" @@ -21,11 +22,11 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("passages", "embedding", existing_type=sa.NUMERIC(), type_=letta.agent_store.db.CommonVector(), existing_nullable=True) + op.alter_column("passages", "embedding", existing_type=sa.NUMERIC(), type_=CommonVector(), existing_nullable=True) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("passages", "embedding", existing_type=letta.agent_store.db.CommonVector(), type_=sa.NUMERIC(), existing_nullable=True) + op.alter_column("passages", "embedding", existing_type=CommonVector(), type_=sa.NUMERIC(), existing_nullable=True) # ### end Alembic commands ### From acfc0394e336e0b2ad322f61e40f3937be2425c4 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 14 Oct 2024 10:05:15 -0700 Subject: [PATCH 22/23] fix alembic --- ...7c830b_add_filemetadata_table_and_drop_.py | 32 ------------------- 1 file changed, 32 deletions(-) delete mode 100644 alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py diff --git a/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py b/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py deleted file mode 100644 index 59ac6d752b..0000000000 --- a/alembic/versions/7fb2327c830b_add_filemetadata_table_and_drop_.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Add FileMetadata table and drop Documents table - -Revision ID: 7fb2327c830b -Revises: 9a505cc7eca9 -Create Date: 2024-10-11 18:07:21.765653 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa - -from alembic import op -from letta.agent_store.db import CommonVector - -# revision identifiers, used by Alembic. -revision: str = "7fb2327c830b" -down_revision: Union[str, None] = "9a505cc7eca9" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("passages", "embedding", existing_type=sa.NUMERIC(), type_=CommonVector(), existing_nullable=True) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.alter_column("passages", "embedding", existing_type=CommonVector(), type_=sa.NUMERIC(), existing_nullable=True) - # ### end Alembic commands ### From aff27dfa8c240b91ab745501c58a3379d849fc3a Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 14 Oct 2024 10:09:10 -0700 Subject: [PATCH 23/23] refactor to filemetadata model --- letta/agent_store/db.py | 6 +++--- letta/metadata.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 15da6d62c0..5e4fc5ae33 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -28,7 +28,7 @@ from letta.base import Base from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM -from letta.metadata import EmbeddingConfigColumn, FileModel, ToolCallColumn +from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.message import Message @@ -373,7 +373,7 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") elif table_type == TableType.FILES: self.uri = self.config.metadata_storage_uri - self.db_model = FileModel + self.db_model = FileMetadataModel if self.config.metadata_storage_uri is None: raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}") else: @@ -498,7 +498,7 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) self.path = self.config.metadata_storage_path if self.path is None: raise ValueError(f"Must specify metadata_storage_path in config.") - self.db_model = FileModel + self.db_model = FileMetadataModel else: raise ValueError(f"Table type {table_type} not implemented") diff --git a/letta/metadata.py b/letta/metadata.py index 7aa46318bb..87473bab50 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -40,7 +40,7 @@ from letta.utils import enforce_types, get_utc_time, printd -class FileModel(Base): +class FileMetadataModel(Base): __tablename__ = "files" __table_args__ = {"extend_existing": True} @@ -906,14 +906,14 @@ def create_job(self, job: Job): def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]): with self.session_maker() as session: # Start with the basic query filtered by source_id - query = session.query(FileModel).filter(FileModel.source_id == source_id) + query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id) if cursor: # Assuming cursor is the ID of the last file in the previous page - query = query.filter(FileModel.id > cursor) + query = query.filter(FileMetadataModel.id > cursor) # Order by ID or other ordering criteria to ensure correct pagination - query = query.order_by(FileModel.id) + query = query.order_by(FileMetadataModel.id) # Limit the number of results returned results = query.limit(limit).all()
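A note for anyone writing a custom data connector against this series: by the later patches the `DataConnector` interface is `find_files` plus `generate_passages` over `FileMetadata`, while the `docs/data_sources.md` snippet reverted in patch 17 still documents the older text-based `generate_passages` signature. Below is a minimal sketch of the new interface, modeled on the `DummyDataConnector` from `tests/utils.py`; the connector class name, file names, and placeholder metadata values are illustrative only, and the exact `FileMetadata` constructor arguments are whatever the schema in your checkout accepts.

```python
from typing import Dict, Iterator, List, Tuple

from letta.data_sources.connectors import DataConnector
from letta.schemas.file import FileMetadata
from letta.schemas.source import Source


class InMemoryTextConnector(DataConnector):
    """Toy connector: each input string is treated as one file and one passage."""

    def __init__(self, texts: List[str]):
        self.texts = texts
        self.file_to_text: Dict[str, str] = {}

    def find_files(self, source: Source) -> Iterator[FileMetadata]:
        for i, text in enumerate(self.texts):
            file_metadata = FileMetadata(
                user_id=source.user_id,
                source_id=source.id,
                file_name=f"text_{i}.txt",       # placeholder name
                file_path=f"/tmp/text_{i}.txt",  # placeholder path
                file_type="text/plain",
                file_size=len(text),
                file_creation_date="1970-01-01",       # placeholder date
                file_last_modified_date="1970-01-01",  # placeholder date
            )
            # remember the text so generate_passages can look it up by file id
            self.file_to_text[file_metadata.id] = text
            yield file_metadata

    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        # a real connector would chunk the file contents; here each file is a single passage
        yield self.file_to_text[file.id], {}
```

Loading would then go through the same `load_data(connector, source, passage_store, document_store)` path shown in the server diff earlier in this series, after which `list_files_from_source` should return the `FileMetadata` rows produced by `find_files`.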