Skip to content

Commit

Permalink
feat: add user identities (#2446)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Feb 21, 2025
2 parents a678e6d + 54b382c commit d23478f
Show file tree
Hide file tree
Showing 23 changed files with 662 additions and 161 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""update identities unique constraint and properties
Revision ID: 549eff097c71
Revises: a3047a624130
Create Date: 2025-02-20 09:53:42.743105
"""

from typing import Sequence, Union

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "549eff097c71"
down_revision: Union[str, None] = "a3047a624130"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Forward migration.

    1. Widens the ``identities`` unique constraint so NULL ``project_id``
       values are treated as equal (no duplicate identifier per org without
       a project).
    2. Adds a JSONB ``properties`` column to ``identities``.
    3. Replaces the single ``agents.identity_id`` FK with a many-to-many
       ``identities_agents`` join table, migrating existing links first.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Update unique constraint on identities table
    op.drop_constraint("unique_identifier_pid_org_id", "identities", type_="unique")
    # NOTE(review): postgresql_nulls_not_distinct requires PostgreSQL 15+ and is
    # Postgres-specific — confirm no other backend runs this migration.
    op.create_unique_constraint(
        "unique_identifier_without_project",
        "identities",
        ["identifier_key", "project_id", "organization_id"],
        postgresql_nulls_not_distinct=True,
    )

    # Add properties column to identities table
    # server_default="[]" backfills existing rows with an empty JSON list so the
    # column can be NOT NULL.
    op.add_column("identities", sa.Column("properties", postgresql.JSONB, nullable=False, server_default="[]"))

    # Create identities_agents table for many-to-many relationship
    op.create_table(
        "identities_agents",
        sa.Column("identity_id", sa.String(), nullable=False),
        sa.Column("agent_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("identity_id", "agent_id"),
    )

    # Migrate existing relationships
    # First, get existing relationships where identity_id is not null
    # (copy each agent's single FK into the join table before dropping it).
    op.execute(
        """
        INSERT INTO identities_agents (identity_id, agent_id)
        SELECT DISTINCT identity_id, id as agent_id
        FROM agents
        WHERE identity_id IS NOT NULL
        """
    )

    # Remove old identity_id column from agents
    op.drop_column("agents", "identity_id")
    op.drop_column("agents", "identifier_key")
    # ### end Alembic commands ###


def downgrade() -> None:
    """Reverse migration: restore the single-FK ``agents.identity_id`` /
    ``agents.identifier_key`` columns, drop the join table and the
    ``properties`` column, and restore the original unique constraint.

    NOTE: an agent linked to multiple identities collapses to one
    arbitrarily-chosen ``identity_id`` — the many-to-many relation cannot be
    represented in the old single-FK schema.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Add back the old columns to agents
    op.add_column("agents", sa.Column("identity_id", sa.String(), nullable=True))
    op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True))

    # Migrate relationships back (non-deterministic pick when an agent has
    # several identities; see docstring).
    op.execute(
        """
        UPDATE agents a
        SET identity_id = ia.identity_id
        FROM identities_agents ia
        WHERE a.id = ia.agent_id
        """
    )

    # Repopulate identifier_key from the linked identity so the downgrade is a
    # true inverse of the upgrade. The upgrade dropped this column without
    # preserving its values; the identity row still holds the same key.
    op.execute(
        """
        UPDATE agents a
        SET identifier_key = i.identifier_key
        FROM identities i
        WHERE a.identity_id = i.id
        """
    )

    # Drop the many-to-many table
    op.drop_table("identities_agents")

    # Drop properties column
    op.drop_column("identities", "properties")

    # Restore old unique constraint
    op.drop_constraint("unique_identifier_without_project", "identities", type_="unique")
    op.create_unique_constraint("unique_identifier_pid_org_id", "identities", ["identifier_key", "project_id", "organization_id"])
    # ### end Alembic commands ###
7 changes: 6 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def _get_ai_reply(
max_delay: float = 10.0, # max delay between retries
step_count: Optional[int] = None,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""
log_telemetry(self.logger, "_get_ai_reply start")
Expand Down Expand Up @@ -367,6 +368,7 @@ def _get_ai_reply(
force_tool_call=force_tool_call,
stream=stream,
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
)
log_telemetry(self.logger, "_get_ai_reply create finish")

Expand Down Expand Up @@ -648,6 +650,7 @@ def step(
# additional args
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
put_inner_thoughts_first: bool = True,
**kwargs,
) -> LettaUsageStatistics:
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
Expand All @@ -662,6 +665,7 @@ def step(
kwargs["last_function_failed"] = function_failed
step_response = self.inner_step(
messages=next_input_message,
put_inner_thoughts_first=put_inner_thoughts_first,
**kwargs,
)

Expand Down Expand Up @@ -743,9 +747,9 @@ def inner_step(
metadata: Optional[dict] = None,
summarize_attempt_count: int = 0,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
) -> AgentStepResponse:
"""Runs a single step in the agent loop (generates at most one LLM call)"""

try:

# Extract job_id from metadata if present
Expand Down Expand Up @@ -778,6 +782,7 @@ def inner_step(
stream=stream,
step_count=step_count,
last_function_failed=last_function_failed,
put_inner_thoughts_first=put_inner_thoughts_first,
)
if not response:
# EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
Expand Down
30 changes: 20 additions & 10 deletions letta/llm_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,29 @@ def add_inner_thoughts_to_functions(
inner_thoughts_key: str,
inner_thoughts_description: str,
inner_thoughts_required: bool = True,
put_inner_thoughts_first: bool = True,
) -> List[dict]:
"""Add an inner_thoughts kwarg to every function in the provided list, ensuring it's the first parameter"""
new_functions = []
for function_object in functions:
new_function_object = copy.deepcopy(function_object)

# Create a new OrderedDict with inner_thoughts as the first item
new_properties = OrderedDict()
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}

# Add the rest of the properties
new_properties.update(function_object["parameters"]["properties"])
# For chat completions, we want inner thoughts to come later
if put_inner_thoughts_first:
# Create with inner_thoughts as the first item
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}
# Add the rest of the properties
new_properties.update(function_object["parameters"]["properties"])
else:
new_properties.update(function_object["parameters"]["properties"])
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}

# Cast OrderedDict back to a regular dict
new_function_object["parameters"]["properties"] = dict(new_properties)
Expand All @@ -225,9 +233,11 @@ def add_inner_thoughts_to_functions(
if inner_thoughts_required:
required_params = new_function_object["parameters"].get("required", [])
if inner_thoughts_key not in required_params:
required_params.insert(0, inner_thoughts_key)
if put_inner_thoughts_first:
required_params.insert(0, inner_thoughts_key)
else:
required_params.append(inner_thoughts_key)
new_function_object["parameters"]["required"] = required_params

new_functions.append(new_function_object)

return new_functions
Expand Down
5 changes: 4 additions & 1 deletion letta/llm_api/llm_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def create(
stream: bool = False,
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
put_inner_thoughts_first: bool = True,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from letta.utils import printd
Expand Down Expand Up @@ -185,7 +186,9 @@ def create(
else:
function_call = "required"

data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming)
data = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, put_inner_thoughts_first=put_inner_thoughts_first
)
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
Expand Down
4 changes: 3 additions & 1 deletion letta/llm_api/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def build_openai_chat_completions_request(
functions: Optional[list],
function_call: Optional[str],
use_tool_naming: bool,
put_inner_thoughts_first: bool = True,
) -> ChatCompletionRequest:
if functions and llm_config.put_inner_thoughts_in_kwargs:
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
Expand All @@ -105,6 +106,7 @@ def build_openai_chat_completions_request(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=inner_thoughts_desc,
put_inner_thoughts_first=put_inner_thoughts_first,
)

openai_message_list = [
Expand Down Expand Up @@ -390,7 +392,7 @@ def openai_chat_completions_process_stream(
chat_completion_response.usage.completion_tokens = n_chunks
chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks

assert len(chat_completion_response.choices) > 0, chat_completion_response
assert len(chat_completion_response.choices) > 0, f"No response from provider {chat_completion_response}"

# printd(chat_completion_response)
return chat_completion_response
Expand Down
1 change: 1 addition & 0 deletions letta/orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.identities_agents import IdentitiesAgents
from letta.orm.identity import Identity
from letta.orm.job import Job
from letta.orm.job_messages import JobMessage
Expand Down
20 changes: 9 additions & 11 deletions letta/orm/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from typing import TYPE_CHECKING, List, Optional

from sqlalchemy import JSON, Boolean, ForeignKey, Index, String
from sqlalchemy import JSON, Boolean, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.block import Block
Expand Down Expand Up @@ -61,14 +61,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.")
base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.")

# Identity
identity_id: Mapped[Optional[str]] = mapped_column(
String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to."
)
identifier_key: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="The identifier key of the identity the agent belongs to."
)

# Tool rules
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")

Expand All @@ -79,7 +71,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):

# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
identity: Mapped["Identity"] = relationship("Identity", back_populates="agents")
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
"AgentEnvironmentVariable",
back_populates="agent",
Expand Down Expand Up @@ -130,7 +121,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived created by this agent.",
)
identity: Mapped[Optional["Identity"]] = relationship("Identity", back_populates="agents")
identities: Mapped[List["Identity"]] = relationship(
"Identity",
secondary="identities_agents",
lazy="selectin",
back_populates="agents",
passive_deletes=True,
)

def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""
Expand Down Expand Up @@ -160,6 +157,7 @@ def to_pydantic(self) -> PydanticAgentState:
"project_id": self.project_id,
"template_id": self.template_id,
"base_template_id": self.base_template_id,
"identity_ids": [identity.id for identity in self.identities],
"message_buffer_autoclear": self.message_buffer_autoclear,
}

Expand Down
13 changes: 13 additions & 0 deletions letta/orm/identities_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column

from letta.orm.base import Base


class IdentitiesAgents(Base):
    """Identities may have one or many agents associated with them.

    Association (join) table backing the many-to-many relationship between
    ``identities`` and ``agents``. Rows are removed automatically when either
    side is deleted (ON DELETE CASCADE on both foreign keys).
    """

    __tablename__ = "identities_agents"

    # Composite primary key: exactly one row per (identity, agent) pair.
    identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True)
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
31 changes: 26 additions & 5 deletions letta/orm/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,49 @@
from typing import List, Optional

from sqlalchemy import String, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityProperty


class Identity(SqlalchemyBase, OrganizationMixin):
"""Identity ORM class"""

__tablename__ = "identities"
__pydantic_model__ = PydanticIdentity
__table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),)
__table_args__ = (
UniqueConstraint(
"identifier_key",
"project_id",
"organization_id",
name="unique_identifier_without_project",
postgresql_nulls_not_distinct=True,
),
)

id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}")
identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.")
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.")
identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.")
project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.")
properties: Mapped[List["IdentityProperty"]] = mapped_column(
JSONB, nullable=False, default=list, doc="List of properties associated with the identity"
)

# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity")
agents: Mapped[List["Agent"]] = relationship(
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
)

@property
def agent_ids(self) -> List[str]:
"""Get just the agent IDs without loading the full agent objects"""
return [agent.id for agent in self.agents]

def to_pydantic(self) -> PydanticIdentity:
state = {
Expand All @@ -33,7 +53,8 @@ def to_pydantic(self) -> PydanticIdentity:
"name": self.name,
"identity_type": self.identity_type,
"project_id": self.project_id,
"agents": [agent.to_pydantic() for agent in self.agents],
"agent_ids": self.agent_ids,
"organization_id": self.organization_id,
"properties": self.properties,
}

return self.__pydantic_model__(**state)
return PydanticIdentity(**state)
4 changes: 4 additions & 0 deletions letta/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def list(
access_type: AccessType = AccessType.ORGANIZATION,
join_model: Optional[Base] = None,
join_conditions: Optional[Union[Tuple, List]] = None,
identifier_keys: Optional[List[str]] = None,
**kwargs,
) -> List["SqlalchemyBase"]:
"""
Expand Down Expand Up @@ -143,6 +144,9 @@ def list(
# Group by primary key and all necessary columns to avoid JSON comparison
query = query.group_by(cls.id)

if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

# Apply filtering logic from kwargs
for key, value in kwargs.items():
if "." in key:
Expand Down
Loading

0 comments on commit d23478f

Please sign in to comment.