diff --git a/.gitignore b/.gitignore
index ddcdb97a72..1fcffd8a4c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,9 @@
 # Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
 # Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
+openapi_letta.json
+openapi_openai.json
+
 ### Eclipse ###
 .metadata
 bin/
diff --git a/docs/generate_docs.py b/docs/generate_docs.py
index c315ad4ce5..8db3bfb20e 100644
--- a/docs/generate_docs.py
+++ b/docs/generate_docs.py
@@ -70,7 +70,7 @@ def generate_modules(config):
         "Message",
         "Passage",
         "AgentState",
-        "Document",
+        "File",
         "Source",
         "LLMConfig",
         "EmbeddingConfig",
diff --git a/examples/notebooks/data_connector.ipynb b/examples/notebooks/data_connector.ipynb
index bd26987a66..ca67d280bb 100644
--- a/examples/notebooks/data_connector.ipynb
+++ b/examples/notebooks/data_connector.ipynb
@@ -270,7 +270,7 @@
    "outputs": [],
    "source": [
     "from letta.data_sources.connectors import DataConnector \n",
-    "from letta.schemas.document import Document\n",
+    "from letta.schemas.file import FileMetadata\n",
     "from llama_index.core import Document as LlamaIndexDocument\n",
     "from llama_index.core import SummaryIndex\n",
     "from llama_index.readers.web import SimpleWebPageReader\n",
diff --git a/letta/__init__.py b/letta/__init__.py
index bc2004172b..bb13c28144 100644
--- a/letta/__init__.py
+++ b/letta/__init__.py
@@ -7,9 +7,9 @@
 # imports for easier access
 from letta.schemas.agent import AgentState
 from letta.schemas.block import Block
-from letta.schemas.document import Document
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import JobStatus
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py
index ff22af8c4d..5e4fc5ae33 100644
--- a/letta/agent_store/db.py
+++ b/letta/agent_store/db.py
@@ -28,7 +28,7 @@
 from letta.base import Base
 from letta.config import LettaConfig
 from letta.constants import MAX_EMBEDDING_DIM
-from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
+from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
 
 # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
 from letta.schemas.message import Message
@@ -141,7 +141,7 @@ class PassageModel(Base):
     id = Column(String, primary_key=True)
     user_id = Column(String, nullable=False)
     text = Column(String)
-    doc_id = Column(String)
+    file_id = Column(String)
     agent_id = Column(String)
     source_id = Column(String)
@@ -160,7 +160,7 @@
     # Add a datetime column, with default value as the current time
     created_at = Column(DateTime(timezone=True))
 
-    Index("passage_idx_user", user_id, agent_id, doc_id),
+    Index("passage_idx_user", user_id, agent_id, file_id),
 
     def __repr__(self):
         return f"
 Iterator[Tuple[str, Dict]]:  # -> Iterator[Document]:
+#         from llama_index.readers.web import SimpleWebPageReader
+#
+#         files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
+#         for document in files:
+#             yield document.text, {"url": document.id_}
+#
+#
+# class VectorDBConnector(DataConnector):
+#     # NOTE: this class has not been properly tested, so is unlikely to work
+#     # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table)
+#
+#     def __init__(
+#         self,
+#         name: str,
+#         uri: str,
+#         table_name: str,
+#         text_column: str,
+#         embedding_column: str,
+#         embedding_dim: int,
+#     ):
+#         self.name = name
+#         self.uri = uri
+#         self.table_name = table_name
+#         self.text_column = text_column
+#         self.embedding_column = embedding_column
+#         self.embedding_dim = embedding_dim
+#
+#         # connect to db table
+#         from sqlalchemy import create_engine
+#
+#         self.engine = create_engine(uri)
+#
+#     def generate_files(self) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Document]:
+#         yield self.table_name, None
+#
+#     def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Passage]:
+#         from pgvector.sqlalchemy import Vector
+#         from sqlalchemy import Inspector, MetaData, Table, select
+#
+#         metadata = MetaData()
+#         # Create an inspector to inspect the database
+#         inspector = Inspector.from_engine(self.engine)
+#         table_names = inspector.get_table_names()
+#         assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
+#
+#         table = Table(self.table_name, metadata, autoload_with=self.engine)
+#
+#         # Prepare a select statement
+#         select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
+#
+#         # Execute the query and fetch the results
+#         # TODO: paginate results
+#         with self.engine.connect() as connection:
+#             result = connection.execute(select_statement).fetchall()
+#
+#         for text, embedding in result:
+#             # assume that embeddings are the same model as in config
+#             # TODO: don't re-compute embedding
+#             yield text, {"embedding": embedding}
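The refactored DataConnector interface itself is not fully visible in this hunk, but judging from the commented-out examples above and the DummyDataConnector changes in tests/utils.py at the end of this diff, connectors now yield one FileMetadata per file from find_files() and per-file passage tuples from generate_passages(). A rough sketch of a custom connector under that assumption (the class name, the text/plain file type, and the fixed-size chunking rule are illustrative, and Source is assumed to expose id and user_id):

import os
from typing import Dict, Iterator, Tuple

from letta.data_sources.connectors import DataConnector
from letta.schemas.file import FileMetadata


class LocalTextConnector(DataConnector):
    """Hypothetical connector that serves a fixed list of local text files."""

    def __init__(self, file_paths):
        self.file_paths = file_paths

    def find_files(self, source) -> Iterator[FileMetadata]:
        for path in self.file_paths:
            # one FileMetadata per file, tied back to the source being loaded
            yield FileMetadata(
                user_id=source.user_id,  # assumes Source carries a user_id
                source_id=source.id,
                file_name=os.path.basename(path),
                file_path=path,
                file_type="text/plain",
                file_size=os.path.getsize(path),
            )

    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        # split the file's text into fixed-size chunks, one passage per chunk
        with open(file.file_path, "r") as f:
            text = f.read()
        for i in range(0, len(text), chunk_size):
            yield text[i : i + chunk_size], {"file_name": file.file_name}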
diff --git a/letta/data_sources/connectors_helper.py b/letta/data_sources/connectors_helper.py
new file mode 100644
index 0000000000..9d32e47255
--- /dev/null
+++ b/letta/data_sources/connectors_helper.py
@@ -0,0 +1,97 @@
+import mimetypes
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import List, Optional
+
+
+def extract_file_metadata(file_path) -> dict:
+    """Extracts metadata from a single file."""
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(file_path)
+
+    file_metadata = {
+        "file_name": os.path.basename(file_path),
+        "file_path": file_path,
+        "file_type": mimetypes.guess_type(file_path)[0] or "unknown",
+        "file_size": os.path.getsize(file_path),
+        "file_creation_date": datetime.fromtimestamp(os.path.getctime(file_path)).strftime("%Y-%m-%d"),
+        "file_last_modified_date": datetime.fromtimestamp(os.path.getmtime(file_path)).strftime("%Y-%m-%d"),
+    }
+    return file_metadata
+
+
+def extract_metadata_from_files(file_list):
+    """Extracts metadata for a list of files."""
+    metadata = []
+    for file_path in file_list:
+        file_metadata = extract_file_metadata(file_path)
+        if file_metadata:
+            metadata.append(file_metadata)
+    return metadata
+
+
+def get_filenames_in_dir(
+    input_dir: str, recursive: bool = True, required_exts: Optional[List[str]] = None, exclude: Optional[List[str]] = None
+):
+    """
+    Recursively reads files from the directory, applying required_exts and exclude filters.
+    Ensures that required_exts and exclude do not overlap.
+
+    Args:
+        input_dir (str): The directory to scan for files.
+        recursive (bool): Whether to scan directories recursively.
+        required_exts (list): List of file extensions to include (e.g., ['pdf', 'txt']).
+            If None or empty, matches any file extension.
+        exclude (list): List of file patterns to exclude (e.g., ['*png', '*jpg']).
+
+    Returns:
+        list: A list of matching file paths.
+    """
+    required_exts = required_exts or []
+    exclude = exclude or []
+
+    # Ensure required_exts and exclude do not overlap
+    ext_set = set(required_exts)
+    exclude_set = set(exclude)
+    overlap = ext_set & exclude_set
+    if overlap:
+        raise ValueError(f"Extensions in required_exts and exclude overlap: {overlap}")
+
+    def is_excluded(file_name):
+        """Check if a file matches any pattern in the exclude list."""
+        for pattern in exclude:
+            if Path(file_name).match(pattern):
+                return True
+        return False
+
+    files = []
+    search_pattern = "**/*" if recursive else "*"
+
+    for file_path in Path(input_dir).glob(search_pattern):
+        if file_path.is_file() and not is_excluded(file_path.name):
+            ext = file_path.suffix.lstrip(".")
+            # If required_exts is empty, match any file
+            if not required_exts or ext in required_exts:
+                files.append(file_path)
+
+    return files
+
+
+def assert_all_files_exist_locally(file_paths: List[str]) -> bool:
+    """
+    Checks if all file paths in the provided list exist locally.
+    Raises a FileNotFoundError with a list of missing files if any do not exist.
+
+    Args:
+        file_paths (List[str]): List of file paths to check.
+
+    Returns:
+        bool: True if all files exist, raises FileNotFoundError if any file is missing.
+    """
+    missing_files = [file_path for file_path in file_paths if not Path(file_path).exists()]
+
+    if missing_files:
+        raise FileNotFoundError(missing_files)
+
+    return True
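For reference, a small usage sketch composing the new helpers (directory scan, existence check, metadata extraction); the ./docs directory and the extension filters are placeholders:

from letta.data_sources.connectors_helper import (
    assert_all_files_exist_locally,
    extract_metadata_from_files,
    get_filenames_in_dir,
)

# collect .pdf and .txt files under ./docs, skipping images
paths = get_filenames_in_dir("./docs", recursive=True, required_exts=["pdf", "txt"], exclude=["*png", "*jpg"])

# get_filenames_in_dir returns pathlib.Path objects, so convert before handing off
path_strs = [str(p) for p in paths]
assert_all_files_exist_locally(path_strs)

for meta in extract_metadata_from_files(path_strs):
    print(meta["file_name"], meta["file_type"], meta["file_size"])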
diff --git a/letta/metadata.py b/letta/metadata.py
index c8f206f349..87473bab50 100644
--- a/letta/metadata.py
+++ b/letta/metadata.py
@@ -11,6 +11,7 @@
     Column,
     DateTime,
     Index,
+    Integer,
     String,
     TypeDecorator,
     desc,
@@ -24,6 +25,7 @@
 from letta.schemas.block import Block, Human, Persona
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import JobStatus
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import Memory
@@ -38,6 +40,41 @@
 from letta.utils import enforce_types, get_utc_time, printd
 
 
+class FileMetadataModel(Base):
+    __tablename__ = "files"
+    __table_args__ = {"extend_existing": True}
+
+    id = Column(String, primary_key=True, nullable=False)
+    user_id = Column(String, nullable=False)
+    # TODO: Investigate why this breaks during table creation due to FK
+    # source_id = Column(String, ForeignKey("sources.id"), nullable=False)
+    source_id = Column(String, nullable=False)
+    file_name = Column(String, nullable=True)
+    file_path = Column(String, nullable=True)
+    file_type = Column(String, nullable=True)
+    file_size = Column(Integer, nullable=True)
+    file_creation_date = Column(String, nullable=True)
+    file_last_modified_date = Column(String, nullable=True)
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+
+    def __repr__(self):
+        return f""
+
+    def to_record(self):
+        return FileMetadata(
+            id=self.id,
+            user_id=self.user_id,
+            source_id=self.source_id,
+            file_name=self.file_name,
+            file_path=self.file_path,
+            file_type=self.file_type,
+            file_size=self.file_size,
+            file_creation_date=self.file_creation_date,
+            file_last_modified_date=self.file_last_modified_date,
+            created_at=self.created_at,
+        )
+
+
 class LLMConfigColumn(TypeDecorator):
     """Custom type for storing LLMConfig as JSON"""
@@ -865,6 +902,27 @@ def create_job(self, job: Job):
             session.add(JobModel(**vars(job)))
             session.commit()
 
+    @enforce_types
+    def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]):
+        with self.session_maker() as session:
+            # Start with the basic query filtered by source_id
+            query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id)
+
+            if cursor:
+                # Assuming cursor is the ID of the last file in the previous page
+                query = query.filter(FileMetadataModel.id > cursor)
+
+            # Order by ID or other ordering criteria to ensure correct pagination
+            query = query.order_by(FileMetadataModel.id)
+
+            # Limit the number of results returned
+            results = query.limit(limit).all()
+
+            # Convert the results to the required FileMetadata objects
+            files = [r.to_record() for r in results]
+
+            return files
+
     def delete_job(self, job_id: str):
         with self.session_maker() as session:
             session.query(JobModel).filter(JobModel.id == job_id).delete()
diff --git a/letta/schemas/document.py b/letta/schemas/document.py
deleted file mode 100644
index 2628ac5702..0000000000
--- a/letta/schemas/document.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Dict, Optional
-
-from pydantic import Field
-
-from letta.schemas.letta_base import LettaBase
-
-
-class DocumentBase(LettaBase):
-    """Base class for document schemas"""
-
-    __id_prefix__ = "doc"
-
-
-class Document(DocumentBase):
-    """Representation of a single document (broken up into `Passage` objects)"""
-
-    id: str = DocumentBase.generate_id_field()
-    text: str = Field(..., description="The text of the document.")
-    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
-    user_id: str = Field(description="The unique identifier of the user associated with the document.")
-    metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")
diff --git a/letta/schemas/file.py b/letta/schemas/file.py
new file mode 100644
index 0000000000..37e392545f
--- /dev/null
+++ b/letta/schemas/file.py
@@ -0,0 +1,31 @@
+from datetime import datetime
+from typing import Optional
+
+from pydantic import Field
+
+from letta.schemas.letta_base import LettaBase
+from letta.utils import get_utc_time
+
+
+class FileMetadataBase(LettaBase):
+    """Base class for FileMetadata schemas"""
+
+    __id_prefix__ = "file"
+
+
+class FileMetadata(FileMetadataBase):
+    """Representation of a single FileMetadata"""
+
+    id: str = FileMetadataBase.generate_id_field()
+    user_id: str = Field(description="The unique identifier of the user associated with the document.")
+    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
+    file_name: Optional[str] = Field(None, description="The name of the file.")
+    file_path: Optional[str] = Field(None, description="The path to the file.")
+    file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
+    file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
+    file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
+    file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
+    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")
+
+    class Config:
+        extra = "allow"
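The dict returned by extract_file_metadata uses the same keys as the new FileMetadata schema, so the two compose directly with keyword expansion; a minimal sketch (the user and source IDs are placeholders):

from letta.data_sources.connectors_helper import extract_file_metadata
from letta.schemas.file import FileMetadata

# helper dict keys line up with the schema's field names
meta = extract_file_metadata("tests/data/test.txt")
file_record = FileMetadata(user_id="user-123", source_id="source-456", **meta)

print(file_record.id)         # auto-generated, using the schema's "file" id prefix
print(file_record.file_type)  # "text/plain" for a .txt file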
diff --git a/letta/schemas/job.py b/letta/schemas/job.py
index da83d4be3d..4499c167cd 100644
--- a/letta/schemas/job.py
+++ b/letta/schemas/job.py
@@ -15,7 +15,7 @@ class JobBase(LettaBase):
 class Job(JobBase):
     """
-    Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding documents).
+    Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding files).
 
     Parameters:
         id (str): The unique identifier of the job.
diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py
index bc3f05e2f2..2ecc5e9ac3 100644
--- a/letta/schemas/passage.py
+++ b/letta/schemas/passage.py
@@ -19,8 +19,8 @@ class PassageBase(LettaBase):
     # origin data source
     source_id: Optional[str] = Field(None, description="The data source of the passage.")
 
-    # document association
-    doc_id: Optional[str] = Field(None, description="The unique identifier of the document associated with the passage.")
+    # file association
+    file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
 
     metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
@@ -36,7 +36,7 @@ class Passage(PassageBase):
         user_id (str): The unique identifier of the user associated with the passage.
         agent_id (str): The unique identifier of the agent associated with the passage.
         source_id (str): The data source of the passage.
-        doc_id (str): The unique identifier of the document associated with the passage.
+        file_id (str): The unique identifier of the file associated with the passage.
     """
 
     id: str = PassageBase.generate_id_field()
diff --git a/letta/schemas/source.py b/letta/schemas/source.py
index 827ebb9f51..8f816ad701 100644
--- a/letta/schemas/source.py
+++ b/letta/schemas/source.py
@@ -28,7 +28,7 @@ class SourceCreate(BaseSource):
 class Source(BaseSource):
     """
-    Representation of a source, which is a collection of documents and passages.
+    Representation of a source, which is a collection of files and passages.
 
     Parameters:
         id (str): The ID of the source
@@ -59,4 +59,4 @@ class UploadFileToSourceRequest(BaseModel):
 class UploadFileToSourceResponse(BaseModel):
     source: Source = Field(..., description="The source the file was uploaded to.")
     added_passages: int = Field(..., description="The number of passages added to the source.")
-    added_documents: int = Field(..., description="The number of documents added to the source.")
+    added_documents: int = Field(..., description="The number of files added to the source.")
diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py
index bd581a98ef..3f3fef17cb 100644
--- a/letta/server/rest_api/routers/v1/jobs.py
+++ b/letta/server/rest_api/routers/v1/jobs.py
@@ -1,6 +1,6 @@
 from typing import List, Optional
 
-from fastapi import APIRouter, Depends, Header, Query
+from fastapi import APIRouter, Depends, Header, HTTPException, Query
 
 from letta.schemas.job import Job
 from letta.server.rest_api.utils import get_letta_server
@@ -54,3 +54,19 @@ def get_job(
     """
 
     return server.get_job(job_id=job_id)
+
+
+@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
+def delete_job(
+    job_id: str,
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    """
+    Delete a job by its job_id.
+    """
+    job = server.get_job(job_id=job_id)
+    if not job:
+        raise HTTPException(status_code=404, detail="Job not found")
+
+    server.delete_job(job_id=job_id)
+    return job
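A quick sketch of exercising the new route over HTTP, assuming the jobs router is mounted under /v1/jobs on a locally running server; the base URL, port, and job ID below are placeholders:

import requests

base_url = "http://localhost:8283/v1/jobs"  # placeholder host/port
job_id = "job-1234"                         # placeholder job ID

resp = requests.delete(f"{base_url}/{job_id}")
if resp.status_code == 404:
    print("job not found")
else:
    # response_model=Job, so the deleted job is echoed back as JSON
    print("deleted job:", resp.json()["id"])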
diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py
index a59abd31bb..c25206b458 100644
--- a/letta/server/rest_api/routers/v1/sources.py
+++ b/letta/server/rest_api/routers/v1/sources.py
@@ -4,7 +4,7 @@
 
 from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
 
-from letta.schemas.document import Document
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.passage import Passage
 from letta.schemas.source import Source, SourceCreate, SourceUpdate
@@ -186,19 +186,17 @@ def list_passages(
     return passages
 
 
-@router.get("/{source_id}/documents", response_model=List[Document], operation_id="list_source_documents")
-def list_documents(
+@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_files_from_source")
+def list_files_from_source(
     source_id: str,
+    limit: int = Query(1000, description="Number of files to return"),
+    cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
     server: "SyncServer" = Depends(get_letta_server),
-    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
-    List all documents associated with a data source.
+    List paginated files associated with a data source.
     """
-    actor = server.get_user_or_default(user_id=user_id)
-
-    documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id)
-    return documents
+    return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
 
 
 def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
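The new endpoint replaces the unpaginated documents listing with keyset pagination: limit caps the page size and cursor is the ID of the last file already seen. A hedged sketch of draining all pages over HTTP (the base URL and source ID are placeholders, and the sources router is assumed to be mounted under /v1/sources):

import requests

base_url = "http://localhost:8283/v1/sources"  # placeholder host/port
source_id = "source-1234"                      # placeholder source ID

cursor, files = None, []
while True:
    params = {"limit": 100}
    if cursor:
        params["cursor"] = cursor
    page = requests.get(f"{base_url}/{source_id}/files", params=params).json()
    if not page:
        break
    files.extend(page)
    cursor = page[-1]["id"]  # keyset cursor: the ID of the last file in this page

print(f"{len(files)} files in source {source_id}")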
diff --git a/letta/server/server.py b/letta/server/server.py
index 08050ac080..987d612660 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -63,11 +63,11 @@
     CreatePersona,
     UpdateBlock,
 )
-from letta.schemas.document import Document
 from letta.schemas.embedding_config import EmbeddingConfig
 
 # openai schemas
 from letta.schemas.enums import JobStatus
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
@@ -1533,7 +1533,7 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Jo
         #     job.status = JobStatus.failed
         #     job.metadata_["error"] = error
         #     self.ms.update_job(job)
-        #     # TODO: delete any associated passages/documents?
+        #     # TODO: delete any associated passages/files?
         #     # return failed job
         #     return job
@@ -1562,11 +1562,10 @@ def load_data(
         # get the data connectors
         passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
-        # TODO: add document store support
-        document_store = None  # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
+        file_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
 
         # load data into the document store
-        passage_count, document_count = load_data(connector, source, passage_store, document_store)
+        passage_count, document_count = load_data(connector, source, passage_store, file_store)
 
         return passage_count, document_count
 
     def attach_source_to_agent(
@@ -1623,14 +1622,14 @@ def list_attached_sources(self, agent_id: str) -> List[Source]:
         # list all attached sources to an agent
         return self.ms.list_attached_sources(agent_id)
 
+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        # list all files belonging to a source, paginated by cursor
+        return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
+
     def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
         warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
         return []
 
-    def list_data_source_documents(self, user_id: str, source_id: str) -> List[Document]:
-        warnings.warn("list_data_source_documents is not yet implemented, returning empty list.", category=UserWarning)
-        return []
-
     def list_all_sources(self, user_id: str) -> List[Source]:
         """List all sources (w/ extra metadata) belonging to a user"""
@@ -1644,9 +1643,9 @@ def list_all_sources(self, user_id: str):
             passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
             num_passages = passage_conn.size({"source_id": source.id})
 
-            # TODO: add when documents table implemented
-            ## count number of documents
-            # document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
+            # TODO: add when files table implemented
+            ## count number of files
+            # document_conn = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
             # num_documents = document_conn.size({"data_source": source.name})
             num_documents = 0
diff --git a/tests/data/test.txt b/tests/data/test.txt
new file mode 100644
index 0000000000..30d74d2584
--- /dev/null
+++ b/tests/data/test.txt
@@ -0,0 +1 @@
+test
\ No newline at end of file
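End to end, the server method just forwards to the metadata store one page at a time, so callers page through results themselves. A sketch against the Python client, mirroring the usage in tests/test_client.py further down (assumes the package exposes create_client and a default local server; the source name is a placeholder):

from letta import create_client  # assumed top-level factory

client = create_client()
source = client.create_source(name="example_source")  # placeholder name

all_files, cursor = [], None
while True:
    page = client.list_files_from_source(source.id, limit=100, cursor=cursor)
    if not page:
        break
    all_files.extend(page)
    cursor = page[-1].id  # next page starts strictly after the last ID seen

print(f"fetched {len(all_files)} files")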
diff --git a/tests/helpers/client_helper.py b/tests/helpers/client_helper.py
new file mode 100644
index 0000000000..feff9e6b6d
--- /dev/null
+++ b/tests/helpers/client_helper.py
@@ -0,0 +1,34 @@
+import time
+from typing import Union
+
+from letta import LocalClient, RESTClient
+from letta.schemas.enums import JobStatus
+from letta.schemas.job import Job
+from letta.schemas.source import Source
+
+
+def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job:
+    # load a file into a source (non-blocking job)
+    upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
+    print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
+
+    # view active jobs
+    active_jobs = client.list_active_jobs()
+    jobs = client.list_jobs()
+    assert upload_job.id in [j.id for j in jobs]
+    assert len(active_jobs) == 1
+    assert active_jobs[0].metadata_["source_id"] == source.id
+
+    # wait for job to finish (with timeout)
+    timeout = 120
+    start_time = time.time()
+    while True:
+        status = client.get_job(upload_job.id).status
+        print(f"\r{status}", end="", flush=True)
+        if status == JobStatus.completed:
+            break
+        time.sleep(1)
+        if time.time() - start_time > timeout:
+            raise ValueError("Job did not finish in time")
+
+    return upload_job
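Because the helper blocks until the upload job reaches JobStatus.completed (or times out after 120 seconds), tests can upload and immediately assert on the resulting files. A hypothetical test built on it, following the fixture style of the existing suite (the test and source names are placeholders):

from tests.helpers.client_helper import upload_file_using_client


def test_upload_then_list(client, agent):
    source = client.create_source(name="helper_demo_source")  # placeholder source name

    # blocks until the upload job completes, then returns the Job
    job = upload_file_using_client(client, source, "tests/data/test.txt")
    assert job is not None

    # the uploaded file should now be listed under the source by its basename
    files = client.list_files_from_source(source.id)
    assert any(f.file_name == "test.txt" for f in files)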
diff --git a/tests/test_client.py b/tests/test_client.py
index fe3e581544..3b606153f5 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -12,12 +12,13 @@
 from letta.constants import DEFAULT_PRESET
 from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import JobStatus, MessageStreamStatus
+from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import FunctionCallMessage, InternalMonologue
 from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
+from tests.helpers.client_helper import upload_file_using_client
 
 # from tests.utils import create_config
@@ -298,6 +299,70 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
     # print("CONFIG", config_response)
 
 
+def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
+    # clear sources
+    for source in client.list_sources():
+        client.delete_source(source.id)
+
+    # clear jobs
+    for job in client.list_jobs():
+        client.delete_job(job.id)
+
+    # create a source
+    source = client.create_source(name="test_source")
+
+    # load files into sources
+    file_a = "tests/data/memgpt_paper.pdf"
+    file_b = "tests/data/test.txt"
+    upload_file_using_client(client, source, file_a)
+    upload_file_using_client(client, source, file_b)
+
+    # Get the first file
+    files_a = client.list_files_from_source(source.id, limit=1)
+    assert len(files_a) == 1
+    assert files_a[0].source_id == source.id
+
+    # Use the cursor from response_a to get the remaining file
+    files_b = client.list_files_from_source(source.id, limit=1, cursor=files_a[-1].id)
+    assert len(files_b) == 1
+    assert files_b[0].source_id == source.id
+
+    # Check files are different to ensure the cursor works
+    assert files_a[0].file_name != files_b[0].file_name
+
+    # Use the cursor from response_b to list files, should be empty
+    files = client.list_files_from_source(source.id, limit=1, cursor=files_b[-1].id)
+    assert len(files) == 0  # Should be empty
+
+
+def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
+    # _reset_config()
+
+    # clear sources
+    for source in client.list_sources():
+        client.delete_source(source.id)
+
+    # clear jobs
+    for job in client.list_jobs():
+        client.delete_job(job.id)
+
+    # create a source
+    source = client.create_source(name="test_source")
+
+    # load a file into a source (non-blocking job)
+    filename = "tests/data/memgpt_paper.pdf"
+    upload_file_using_client(client, source, filename)
+
+    # Get the files
+    files = client.list_files_from_source(source.id)
+    assert len(files) == 1  # Should be condensed to one file
+
+    # Get the memgpt paper
+    file = files[0]
+    assert file.file_name == "memgpt_paper.pdf"
+    assert file.source_id == source.id
+
+
 def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
     # _reset_config()
 
@@ -305,6 +370,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
     for source in client.list_sources():
         client.delete_source(source.id)
 
+    # clear jobs
+    for job in client.list_jobs():
+        client.delete_job(job.id)
+
     # list sources
     sources = client.list_sources()
     print("listed sources", sources)
@@ -343,28 +412,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
 
     # load a file into a source (non-blocking job)
     filename = "tests/data/memgpt_paper.pdf"
-    upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
-    print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
-
-    # view active jobs
-    active_jobs = client.list_active_jobs()
-    jobs = client.list_jobs()
-    print(jobs)
-    assert upload_job.id in [j.id for j in jobs]
-    assert len(active_jobs) == 1
-    assert active_jobs[0].metadata_["source_id"] == source.id
-
-    # wait for job to finish (with timeout)
-    timeout = 120
-    start_time = time.time()
-    while True:
-        status = client.get_job(upload_job.id).status
-        print(status)
-        if status == JobStatus.completed:
-            break
-        time.sleep(1)
-        if time.time() - start_time > timeout:
-            raise ValueError("Job did not finish in time")
+    upload_job = upload_file_using_client(client, source, filename)
 
     job = client.get_job(upload_job.id)
     created_passages = job.metadata_["num_passages"]
diff --git a/tests/utils.py b/tests/utils.py
index f1c6cbfc02..2168e2e387 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,6 @@
 import datetime
 import os
+from datetime import datetime
 from importlib import util
 from typing import Dict, Iterator, List, Tuple
@@ -7,7 +8,7 @@
 from letta.config import LettaConfig
 from letta.data_sources.connectors import DataConnector
-from letta.schemas.document import Document
+from letta.schemas.file import FileMetadata
 from letta.settings import TestSettings
 
 from .constants import TIMEOUT
@@ -18,14 +19,27 @@ class DummyDataConnector(DataConnector):
     def __init__(self, texts: List[str]):
         self.texts = texts
+        self.file_to_text = {}
 
-    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
+    def find_files(self, source) -> Iterator[FileMetadata]:
         for text in self.texts:
-            yield text, {"metadata": "dummy"}
-
-    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]:
-        for doc in documents:
-            yield doc.text, doc.metadata_
+            file_metadata = FileMetadata(
+                user_id="",
+                source_id="",
+                file_name="",
+                file_path="",
+                file_type="",
+                file_size=0,  # Set to 0 as a placeholder
+                file_creation_date="1970-01-01",  # Placeholder date
+                file_last_modified_date="1970-01-01",  # Placeholder date
+                created_at=datetime.utcnow(),
+            )
+            self.file_to_text[file_metadata.id] = text
+
+            yield file_metadata
+
+    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]:
+        yield self.file_to_text[file.id], {}
 
 def wipe_config():
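Finally, a small illustration of how the reworked DummyDataConnector ties files to passages: each FileMetadata yielded by find_files carries an auto-generated ID that keys back into file_to_text, which generate_passages then reads:

from tests.utils import DummyDataConnector

connector = DummyDataConnector(texts=["hello world", "second file"])

for file_metadata in connector.find_files(source=None):  # the dummy ignores its source argument
    for text, metadata in connector.generate_passages(file_metadata):
        # each dummy file maps to exactly one passage with empty metadata
        print(file_metadata.id, text, metadata)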