Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch duplicate starter messages #2123

Closed
wants to merge 10 commits into from
6 changes: 6 additions & 0 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ class CreateAgent(BaseAgent): #
None, description="The initial set of messages to put in the agent's in-context memory."
)

# whether to include default tools
include_default_tools: Optional[bool] = Field(True, description="Whether to include default tools in the agent.")

@field_validator("name")
@classmethod
def validate_name(cls, name: str) -> str:
Expand Down Expand Up @@ -163,6 +166,9 @@ class UpdateAgentState(BaseAgent):
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")

# update blocks
memory_blocks: Optional[List[CreateBlock]] = Field(None, description="The blocks to create in the agent's in-context memory.")

# TODO: determine if these should be editable via this schema?
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")

Expand Down
51 changes: 17 additions & 34 deletions letta/schemas/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ class BaseBlock(LettaBase, validate_assignment=True):
description: Optional[str] = Field(None, description="Description of the block.")
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")

# def __len__(self):
# return len(self.value)

class Config:
extra = "ignore" # Ignores extra fields

Expand Down Expand Up @@ -88,32 +85,13 @@ class Persona(Block):
label: str = "persona"


# class CreateBlock(BaseBlock):
# """Create a block"""
#
# is_template: bool = True
# label: str = Field(..., description="Label of the block.")


class BlockLabelUpdate(BaseModel):
"""Update the label of a block"""

current_label: str = Field(..., description="Current label of the block.")
new_label: str = Field(..., description="New label of the block.")


# class CreatePersona(CreateBlock):
# """Create a persona block"""
#
# label: str = "persona"
#
#
# class CreateHuman(CreateBlock):
# """Create a human block"""
#
# label: str = "human"


class BlockUpdate(BaseBlock):
"""Update a block"""

Expand All @@ -131,18 +109,6 @@ class BlockLimitUpdate(BaseModel):
limit: int = Field(..., description="New limit of the block.")


# class UpdatePersona(BlockUpdate):
# """Update a persona block"""
#
# label: str = "persona"
#
#
# class UpdateHuman(BlockUpdate):
# """Update a human block"""
#
# label: str = "human"


class CreateBlock(BaseBlock):
"""Create a block"""

Expand All @@ -155,6 +121,23 @@ class CreateBlock(BaseBlock):
template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name")


class UpdateBlock(BaseBlock):
"""Update a block"""

label: Optional[str] = Field(None, description="Label of the block.")
limit: Optional[int] = Field(None, description="Character limit of the block.")
value: Optional[str] = Field(None, description="Value of the block.")


class UpdateAgentBlock:

label: str = Field(..., description="The label of the block.")
block: UpdateBlock = Field(..., description="The block to update in the agent's in-context memory.")


# Updating agent blocks: List[UpdateAgentBlock]


class CreateHuman(CreateBlock):
"""Create a human block"""

Expand Down
83 changes: 56 additions & 27 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def list_agents(


@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="get_agent_context_window")
def get_agent_context_window(
async def get_agent_context_window(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
Expand All @@ -72,8 +72,10 @@ def get_agent_context_window(
Retrieve the context window of a specific agent.
"""
actor = server.get_user_or_default(user_id=user_id)

return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
window = server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
return window


@router.post("/", response_model=AgentState, operation_id="create_agent")
Expand Down Expand Up @@ -108,8 +110,19 @@ def get_tools_from_agent(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Get tools from an existing agent"""
# TODO: this should be removed since the agent will return all the data as well
actor = server.get_user_or_default(user_id=user_id)
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)

agent_state = server.get_agent(agent_id=agent_id)
if not agent_state:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
tool_names = agent_state.tool_names
tools = []
for tool_name in tool_names:
tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
tools.append(tool)

return tools


@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
Expand All @@ -121,7 +134,14 @@ def add_tool_to_agent(
):
"""Add tools to an existing agent"""
actor = server.get_user_or_default(user_id=user_id)
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)

agent_state = server.get_agent_state(user_id=actor.id, agent_id=agent_id)
if not agent_state:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
update_request = UpdateAgentState(tools=agent_state.tool_names + [tool.name])
updated_agent_state = server.update_agent(request=update_request, actor=actor)
return updated_agent_state


@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
Expand All @@ -133,7 +153,14 @@ def remove_tool_from_agent(
):
"""Add tools to an existing agent"""
actor = server.get_user_or_default(user_id=user_id)
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)

agent_state = server.get_agent_state(user_id=actor.id, agent_id=agent_id)
if not agent_state:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
update_request = UpdateAgentState(tools=[tool_name for tool_name in agent_state.tool_names if tool_name != tool.name])
updated_agent_state = server.update_agent(request=update_request, actor=actor)
return updated_agent_state


@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
Expand Down Expand Up @@ -176,7 +203,7 @@ def get_agent_sources(
"""
Get the sources associated with an agent.
"""

# TODO: this should not need the lock
return server.list_attached_sources(agent_id)


Expand All @@ -188,22 +215,22 @@ def get_agent_in_context_messages(
"""
Retrieve the messages in the context of a specific agent.
"""

return server.get_in_context_messages(agent_id=agent_id)


# TODO: remove? can also get with agent blocks
@router.get("/{agent_id}/memory", response_model=Memory, operation_id="get_agent_memory")
def get_agent_memory(
async def get_agent_memory(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
):
"""
Retrieve the memory state of a specific agent.
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""

return server.get_agent_memory(agent_id=agent_id)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
return server.get_agent_memory(agent_id=agent_id)


@router.get("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="get_agent_memory_block")
Expand All @@ -217,7 +244,6 @@ def get_agent_memory_block(
Retrieve a memory block from an agent.
"""
actor = server.get_user_or_default(user_id=user_id)

block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label)
return server.block_manager.get_block_by_id(block_id, actor=actor)

Expand Down Expand Up @@ -299,31 +325,33 @@ def update_agent_memory_block(


@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
def get_agent_recall_memory_summary(
async def get_agent_recall_memory_summary(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
):
"""
Retrieve the summary of the recall memory of a specific agent.
"""

return server.get_recall_memory_summary(agent_id=agent_id)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
return server.get_recall_memory_summary(agent_id=agent_id)


@router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary, operation_id="get_agent_archival_memory_summary")
def get_agent_archival_memory_summary(
async def get_agent_archival_memory_summary(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
):
"""
Retrieve the summary of the archival memory of a specific agent.
"""

return server.get_archival_memory_summary(agent_id=agent_id)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
return server.get_archival_memory_summary(agent_id=agent_id)


@router.get("/{agent_id}/archival", response_model=List[Passage], operation_id="list_agent_archival_memory")
def get_agent_archival_memory(
async def get_agent_archival_memory(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
Expand All @@ -339,14 +367,15 @@ def get_agent_archival_memory(
# TODO need to add support for non-postgres here
# chroma will throw:
# raise ValueError("Cannot run get_all_cursor with chroma")

return server.get_agent_archival_cursor(
user_id=actor.id,
agent_id=agent_id,
after=after,
before=before,
limit=limit,
)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
return server.get_agent_archival_cursor(
user_id=actor.id,
agent_id=agent_id,
after=after,
before=before,
limit=limit,
)


@router.post("/{agent_id}/archival", response_model=List[Passage], operation_id="create_agent_archival_memory")
Expand Down
4 changes: 3 additions & 1 deletion letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def list_tools(
"""
try:
actor = server.get_user_or_default(user_id=user_id)
return server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit)
tools = server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit)
print("done listing tools")
return tools
except Exception as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
Expand Down
34 changes: 22 additions & 12 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,13 @@ def load_agent(self, agent_id: str, interface: Union[AgentInterface, None] = Non

interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
return Agent(agent_state=agent_state, interface=interface, user=actor)
agent = Agent(agent_state=agent_state, interface=interface, user=actor)
else:
return O1Agent(agent_state=agent_state, interface=interface, user=actor)
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)

# this is necessary to make sure initial message sequences on the first initialization are saved
save_agent(agent, self.ms)
return agent

def _step(
self,
Expand Down Expand Up @@ -947,14 +951,20 @@ def update_agent(

def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
"""Get tools from an existing agent"""
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# TODO deprecate or remove duplicate code in FastAPI route

# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id)
return letta_agent.agent_state.tools
actor = self.get_user_or_default(user_id=user_id)

agent_state = self.get_agent(agent_id=agent_id)
if not agent_state:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
tool_names = agent_state.tool_names
tools = []
for tool_name in tool_names:
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
tools.append(tool)

return tools

def add_tool_to_agent(
self,
Expand Down Expand Up @@ -1715,15 +1725,15 @@ def link_block_to_agent_memory(self, user_id: str, agent_id: str, block_id: str)
self.blocks_agents_manager.add_block_to_agent(agent_id, block_id, block_label=block.label)

# get agent memory
memory = self.load_agent(agent_id=agent_id).agent_state.memory
memory = self.get_agent(agent_id=agent_id).memory
return memory

def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_label: str, delete_if_no_ref: bool = True) -> Memory:
"""Unlink a block from an agent's memory. If the block is not linked to any agent, delete it."""
self.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=agent_id, block_label=block_label)

# get agent memory
memory = self.load_agent(agent_id=agent_id).agent_state.memory
memory = self.get_agent(agent_id=agent_id).memory
return memory

def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory:
Expand All @@ -1733,7 +1743,7 @@ def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: st
block_id=block.id, block_update=BlockUpdate(limit=limit), actor=self.user_manager.get_user_by_id(user_id=user_id)
)
# get agent memory
memory = self.load_agent(agent_id=agent_id).agent_state.memory
memory = self.get_agent(agent_id=agent_id).memory
return memory

def upate_block(self, user_id: str, block_id: str, block_update: BlockUpdate) -> Block:
Expand Down
9 changes: 7 additions & 2 deletions letta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,11 +1015,16 @@ def get_persona_text(name: str, enforce_limit=True):
raise ValueError(f"Persona {name}.txt not found")


def get_human_text(name: str):
def get_human_text(name: str, enforce_limit: bool = True):
for file_path in list_human_files():
file = os.path.basename(file_path)
if f"{name}.txt" == file or name == file:
return open(file_path, "r", encoding="utf-8").read().strip()
human_text = open(file_path, "r", encoding="utf-8").read().strip()
if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
return human_text

raise ValueError(f"Human {name}.txt not found")


def get_schema_diff(schema_a, schema_b):
Expand Down
Loading
Loading