Commit 92099ac: merge
sarahwooders committed Dec 11, 2024
2 parents: f6f0f69 + d0dbb9d

Showing 42 changed files with 1,686 additions and 2,291 deletions.
.github/workflows/code_style_checks.yml (1 addition, 1 deletion)

@@ -25,7 +25,7 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}
         poetry-version: "1.8.2"
-        install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" # Adjust as necessary
+        install-args: "-E dev -E postgres -E external-tools -E tests" # Adjust as necessary
 
     - name: Validate PR Title
       if: github.event_name == 'pull_request'
.github/workflows/integration_tests.yml (1 addition, 1 deletion)

@@ -55,7 +55,7 @@ jobs:
       with:
         python-version: "3.12"
         poetry-version: "1.8.2"
-        install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox"
+        install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
     - name: Migrate database
       env:
         LETTA_PG_PORT: 5432
.github/workflows/test-pip-install.yml (1 addition, 1 deletion)

@@ -7,7 +7,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12"] # Adjust Python versions as needed
+        python-version: ["3.10", "3.11", "3.12", "3.13"] # Adjust Python versions as needed
 
     steps:
       - uses: actions/checkout@v2
.github/workflows/test_ollama.yml (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ jobs:
       with:
         python-version: "3.12"
         poetry-version: "1.8.2"
-        install-args: "-E dev -E ollama"
+        install-args: "-E dev"
 
     - name: Test LLM Endpoint
       run: |
.github/workflows/tests.yml (2 additions, 2 deletions)

@@ -60,7 +60,7 @@ jobs:
       with:
         python-version: "3.12"
         poetry-version: "1.8.2"
-        install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
+        install-args: "-E dev -E postgres -E external-tools -E tests"
     - name: Migrate database
       env:
         LETTA_PG_PORT: 5432
@@ -113,7 +113,7 @@ jobs:
       with:
         python-version: "3.12"
         poetry-version: "1.8.2"
-        install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox"
+        install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
     - name: Migrate database
       env:
         LETTA_PG_PORT: 5432
.gitignore (3 additions, 0 deletions)

@@ -1022,3 +1022,6 @@ memgpy/pytest.ini
 
 ## ignore venvs
 tests/test_tool_sandbox/restaurant_management_system/venv
+
+## custom scripts
+test
New file: Alembic migration, revision c5d964280dff (88 additions)

@@ -0,0 +1,88 @@
+"""Add Passages ORM, drop legacy passages, cascading deletes for file-passages and user-jobs
+
+Revision ID: c5d964280dff
+Revises: a91994b9752f
+Create Date: 2024-12-10 15:05:32.335519
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision: str = 'c5d964280dff'
+down_revision: Union[str, None] = 'a91994b9752f'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True))
+    op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False))
+    op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True))
+    op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True))
+
+    # Data migration step:
+    op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
+    # Populate `organization_id` based on `user_id`
+    # Use a raw SQL query to update the organization_id
+    op.execute(
+        """
+        UPDATE passages
+        SET organization_id = users.organization_id
+        FROM users
+        WHERE passages.user_id = users.id
+        """
+    )
+
+    # Set `organization_id` as non-nullable after population
+    op.alter_column("passages", "organization_id", nullable=False)
+
+    op.alter_column('passages', 'text',
+               existing_type=sa.VARCHAR(),
+               nullable=False)
+    op.alter_column('passages', 'embedding_config',
+               existing_type=postgresql.JSON(astext_type=sa.Text()),
+               nullable=False)
+    op.alter_column('passages', 'metadata_',
+               existing_type=postgresql.JSON(astext_type=sa.Text()),
+               nullable=False)
+    op.alter_column('passages', 'created_at',
+               existing_type=postgresql.TIMESTAMP(timezone=True),
+               nullable=False)
+    op.drop_index('passage_idx_user', table_name='passages')
+    op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id'])
+    op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id'])
+    op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE')
+    op.drop_column('passages', 'user_id')
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False))
+    op.drop_constraint(None, 'passages', type_='foreignkey')
+    op.drop_constraint(None, 'passages', type_='foreignkey')
+    op.drop_constraint(None, 'passages', type_='foreignkey')
+    op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False)
+    op.alter_column('passages', 'created_at',
+               existing_type=postgresql.TIMESTAMP(timezone=True),
+               nullable=True)
+    op.alter_column('passages', 'metadata_',
+               existing_type=postgresql.JSON(astext_type=sa.Text()),
+               nullable=True)
+    op.alter_column('passages', 'embedding_config',
+               existing_type=postgresql.JSON(astext_type=sa.Text()),
+               nullable=True)
+    op.alter_column('passages', 'text',
+               existing_type=sa.VARCHAR(),
+               nullable=True)
+    op.drop_column('passages', 'organization_id')
+    op.drop_column('passages', '_last_updated_by_id')
+    op.drop_column('passages', '_created_by_id')
+    op.drop_column('passages', 'is_deleted')
+    op.drop_column('passages', 'updated_at')
+    # ### end Alembic commands ###
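
One way to exercise this revision locally is Alembic's Python API; the following is a sketch, assuming a standard alembic.ini in the working directory and a reachable database (adjust both to your checkout):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config path; adjust to your setup

# Apply migrations up to and including this revision.
command.upgrade(cfg, "c5d964280dff")

# Roll back to the prior revision if needed.
command.downgrade(cfg, "a91994b9752f")

Note that the autogenerated op.drop_constraint(None, 'passages', type_='foreignkey') calls in downgrade() still carry None constraint names; Alembic requires real names there, so the downgrade path likely needs those filled in before it will run.
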
letta/agent.py (32 additions, 43 deletions)

@@ -28,7 +28,7 @@
 from letta.llm_api.helpers import is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
-from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages
+from letta.memory import summarize_messages
 from letta.metadata import MetadataStore
 from letta.orm import User
 from letta.schemas.agent import AgentState, AgentStepResponse
@@ -52,6 +52,7 @@
 from letta.schemas.user import User as PydanticUser
 from letta.services.block_manager import BlockManager
 from letta.services.message_manager import MessageManager
+from letta.services.passage_manager import PassageManager
 from letta.services.source_manager import SourceManager
 from letta.services.tool_execution_sandbox import ToolExecutionSandbox
 from letta.services.user_manager import UserManager
@@ -85,7 +86,7 @@ def compile_memory_metadata_block(
     actor: PydanticUser,
     agent_id: str,
     memory_edit_timestamp: datetime.datetime,
-    archival_memory: Optional[ArchivalMemory] = None,
+    passage_manager: Optional[PassageManager] = None,
     message_manager: Optional[MessageManager] = None,
 ) -> str:
     # Put the timestamp in the local timezone (mimicking get_local_time())
@@ -96,7 +97,7 @@
         [
             f"### Memory [last modified: {timestamp_str}]",
             f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
-            f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
+            f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
             "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
         ]
     )
@@ -109,7 +110,7 @@ def compile_system_message(
     in_context_memory: Memory,
     in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
     actor: PydanticUser,
-    archival_memory: Optional[ArchivalMemory] = None,
+    passage_manager: Optional[PassageManager] = None,
     message_manager: Optional[MessageManager] = None,
     user_defined_variables: Optional[dict] = None,
     append_icm_if_missing: bool = True,
@@ -138,7 +139,7 @@
             actor=actor,
             agent_id=agent_id,
             memory_edit_timestamp=in_context_memory_last_edit,
-            archival_memory=archival_memory,
+            passage_manager=passage_manager,
             message_manager=message_manager,
         )
         full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
@@ -175,7 +176,7 @@ def initialize_message_sequence(
     agent_id: str,
     memory: Memory,
     actor: PydanticUser,
-    archival_memory: Optional[ArchivalMemory] = None,
+    passage_manager: Optional[PassageManager] = None,
     message_manager: Optional[MessageManager] = None,
     memory_edit_timestamp: Optional[datetime.datetime] = None,
     include_initial_boot_message: bool = True,
@@ -184,15 +185,15 @@
         memory_edit_timestamp = get_local_time()
 
     # full_system_message = construct_system_with_memory(
-    #    system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
+    #    system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
     # )
     full_system_message = compile_system_message(
         agent_id=agent_id,
         system_prompt=system,
         in_context_memory=memory,
         in_context_memory_last_edit=memory_edit_timestamp,
         actor=actor,
-        archival_memory=archival_memory,
+        passage_manager=passage_manager,
         message_manager=message_manager,
         user_defined_variables=None,
         append_icm_if_missing=True,
@@ -294,7 +295,7 @@ def __init__(
         self.interface = interface
 
         # Create the persistence manager object based on the AgentState info
-        self.archival_memory = EmbeddingArchivalMemory(agent_state)
+        self.passage_manager = PassageManager()
         self.message_manager = MessageManager()
 
         # State needed for heartbeat pausing
@@ -325,7 +326,7 @@ def __init__(
                 agent_id=self.agent_state.id,
                 memory=self.agent_state.memory,
                 actor=self.user,
-                archival_memory=None,
+                passage_manager=None,
                 message_manager=None,
                 memory_edit_timestamp=get_utc_time(),
                 include_initial_boot_message=True,
@@ -350,7 +351,7 @@ def __init__(
                 memory=self.agent_state.memory,
                 agent_id=self.agent_state.id,
                 actor=self.user,
-                archival_memory=None,
+                passage_manager=None,
                 message_manager=None,
                 memory_edit_timestamp=get_utc_time(),
                 include_initial_boot_message=True,
@@ -1306,7 +1307,7 @@ def rebuild_system_prompt(self, force=False, update_timestamp=True):
             in_context_memory=self.agent_state.memory,
             in_context_memory_last_edit=memory_edit_timestamp,
             actor=self.user,
-            archival_memory=self.archival_memory,
+            passage_manager=self.passage_manager,
             message_manager=self.message_manager,
             user_defined_variables=None,
             append_icm_if_missing=True,
@@ -1371,45 +1372,33 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig):
         # TODO: recall memory
         raise NotImplementedError()
 
-    def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
+    def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore):
         """Attach data with name `source_name` to the agent from source_connector."""
         # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
-        user = UserManager().get_user_by_id(self.agent_state.user_id)
-        filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
-        size = source_connector.size(filters)
         page_size = 100
-        generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage]
-        all_passages = []
-        for i in tqdm(range(0, size, page_size)):
-            passages = next(generator)
+        passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
 
-            # need to associated passage with agent (for filtering)
-            for passage in passages:
-                assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
-                passage.agent_id = self.agent_state.id
+        for passage in passages:
+            assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
+            passage.agent_id = self.agent_state.id
+            self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
 
-            # regenerate passage ID (avoid duplicates)
-            # TODO: need to find another solution to the text duplication issue
-            # passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
-
-            # insert into agent archival memory
-            self.archival_memory.storage.insert_many(passages)
-            all_passages += passages
-
-        assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
-
-        # save destination storage
-        self.archival_memory.storage.save()
+        agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
+        passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
+        assert all([p.agent_id == self.agent_state.id for p in agents_passages])
+        assert len(agents_passages) == passage_size # sanity check
+        assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
 
         # attach to agent
-        source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
+        source = source_manager.get_source_by_id(source_id=source_id, actor=user)
         assert source is not None, f"Source {source_id} not found in metadata store"
         ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
 
-        total_agent_passages = self.archival_memory.storage.size()
+        # NOTE: need this redundant line here because we haven't migrated agent to ORM yet
+        # TODO: delete @matt and remove
+        ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
 
         printd(
-            f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
+            f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
         )
 
     def update_message(self, message_id: str, request: MessageUpdate) -> Message:
@@ -1565,13 +1554,13 @@ def get_context_window(self) -> ContextWindowOverview:
             num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
         )
 
-        num_archival_memory = self.archival_memory.storage.size()
+        passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
         message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
         external_memory_summary = compile_memory_metadata_block(
             actor=self.user,
             agent_id=self.agent_state.id,
             memory_edit_timestamp=get_utc_time(), # dummy timestamp
-            archival_memory=self.archival_memory,
+            passage_manager=self.passage_manager,
             message_manager=self.message_manager,
         )
         num_tokens_external_memory_summary = count_tokens(external_memory_summary)
@@ -1597,7 +1586,7 @@ def get_context_window(self) -> ContextWindowOverview:
         return ContextWindowOverview(
             # context window breakdown (in messages)
            num_messages=len(self._messages),
-            num_archival_memory=num_archival_memory,
+            num_archival_memory=passage_manager_size,
             num_recall_memory=message_manager_size,
             num_tokens_external_memory_summary=num_tokens_external_memory_summary,
             # top-level information
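
Taken together, the letta/agent.py changes replace the EmbeddingArchivalMemory wrapper with the PassageManager service for counting archival passages, listing a source's passages, and re-associating them with an agent. A minimal sketch of those call shapes, with signatures inferred from this diff rather than from the canonical PassageManager API, and a hypothetical helper name:

from letta.schemas.user import User as PydanticUser
from letta.services.passage_manager import PassageManager


def reattach_source_passages(user: PydanticUser, agent_id: str, source_id: str) -> int:
    """Hypothetical helper: tie a source's passages to an agent, then count them."""
    passage_manager = PassageManager()

    # List the source's passages (same call shape as attach_source above).
    passages = passage_manager.list_passages(actor=user, source_id=source_id, limit=100)

    # Re-associate each passage with the agent, mirroring the diff's loop.
    for passage in passages:
        passage.agent_id = agent_id
        passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)

    # size() is the bookkeeping call that replaced archival_memory.count().
    return passage_manager.size(actor=user, agent_id=agent_id)
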