Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions athena/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from .message import Message

from .memory import (
# BaseModel classes for business logic
MemoryUnit,
MemoryContext,
MemoryTimestamp,
Task,
State,
Action,
Result,
# SQLModel classes for database operations
MemoryUnitDB,
)

__all__ = [
# Message exports
"Message",
# Memory BaseModel exports
"MemoryUnit",
"MemoryContext",
"MemoryTimestamp",
"Task",
"State",
"Action",
"Result",
# Memory SQLModel exports
"MemoryUnitDB",
]
249 changes: 249 additions & 0 deletions athena/models/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import json
from typing import Literal, Optional, Dict, Any
from pydantic import BaseModel, Field
from sqlmodel import SQLModel, UniqueConstraint
from sqlalchemy import Column, DateTime, Text
from datetime import datetime
import uuid


class MemoryTimestamp(BaseModel):
"""Lifecycle timestamps for a memory unit."""

created_at: datetime = Field(None, description="When the memory unit was first created")
updated_at: Optional[datetime] = Field(
None, description="When the memory was last updated/refreshed"
)
invalid_at: Optional[datetime] = Field(
None, description="When the memory was invalidated or expired"
)


class MemoryContext(BaseModel):
"""Context metadata for a memory unit."""

memory_id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique identifier for this memory unit",
)
source: str = Field(
..., description="Memory source, e.g., agent name, model name, dataset name, or file path"
)
run_id: str = Field(..., description="Unique id for agent run or trajectory dataset")
timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Optional extension field for custom metadata"
)


class Task(BaseModel):
"""Task information extracted from messages or issues."""

issue_title: str = Field(..., description="Short title of the issue/task")
issue_body: str = Field(..., description="Detailed description of the issue/task")
issue_comments: str = Field(..., description="Relevant comments or discussions")
issue_type: str = Field(
..., description="Type of issue: bug | feature | documentation | question | other"
)
repository: str = Field(..., description="Repository where the issue/task belongs")


class State(BaseModel):
"""State information including completed and pending work."""

done: str = Field(..., description="Summary of what has already been completed")
todo: str = Field(..., description="Summary of what remains to be done")
open_file: Optional[str] = Field(
None,
description="Path of the currently opened or edited file, if any (e.g., /sympy__sympy/reproduce_bug.py)",
)
working_dir: Optional[str] = Field(
None, description="Path of the current working directory (e.g., /sympy__sympy)"
)
extra_environment: Dict[str, str] = Field(
default_factory=dict,
description="Other SWE-specific runtime state, e.g. active branch, virtualenv, configs, etc.",
)


class Action(BaseModel):
"""Action taken by the agent."""

name: str = Field(
..., description="Action type (read_file | edit_file | run_test | invoke_tool | etc.)"
)
description: str = Field(..., description="Detailed description of the action")
target: str = Field(
...,
description="Target of the action (e.g., utils/math.py, tests/test_math.py, pytest, git)",
)
tool: str = Field(
..., description="Tool used for the action (e.g., pytest, git, search, editor, bash)"
)


ResultType = Literal["success", "failure", "partial", "unknown"]


class Result(BaseModel):
"""Execution result produced by an action."""

type: ResultType = Field(
"unknown", description="Execution outcome: success | failure | partial | unknown"
)
description: str = Field("", description="Summary of the result")
exit_code: Optional[int] = Field(
None, description="Exit code if extractable from logs (optional)"
)


class MemoryUnit(BaseModel):
"""
Core memory unit capturing one action of agent execution.

This includes:
- The memory context (memory_id, source, run_id, timestamps, metadata)
- The task being worked on (issue and repository details)
- The current state (what's done, what's todo)
- The action taken by the agent
- The result of the action
"""

context: MemoryContext
task: Task
state: State
action: Action
result: Result


# Database models for persistent storage
class MemoryUnitDB(SQLModel, table=True):
"""Database model for persistent storage of memory units."""

__tablename__ = "memory_units"
__table_args__ = UniqueConstraint("memory_id", name="uq_memory_id")

id: Optional[int] = Field(default=None, primary_key=True)

# Context information
memory_id: str
memory_source: str
memory_run_id: str
memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
memory_updated_at: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime(timezone=True))
)
memory_invalid_at: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime(timezone=True))
)
memory_metadata: str = Field(
default="{}", sa_column=Column(Text)
) # JSON string for Dict[str, Any]

# Task information
task_issue_title: str = Field(sa_column=Column(Text))
task_issue_body: str = Field(sa_column=Column(Text))
task_issue_comments: str = Field(sa_column=Column(Text))
task_issue_type: str
task_repository: str

# State information
state_done: str = Field(sa_column=Column(Text))
state_todo: str = Field(sa_column=Column(Text))
state_open_file: Optional[str] = None
state_working_dir: Optional[str] = None
state_extra_environment: str = Field(
default="{}", sa_column=Column(Text)
) # JSON string for Dict[str, str]

# Action information
action_name: str
action_description: str = Field(sa_column=Column(Text))
action_target: str
action_tool: str

# Result information
result_type: str
result_description: str = Field(sa_column=Column(Text))
result_exit_code: Optional[int] = None

@classmethod
def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB":
"""Create a database model from a MemoryUnit."""
return cls(
memory_id=memory_unit.context.memory_id,
memory_source=memory_unit.context.source,
memory_run_id=memory_unit.context.run_id,
memory_created_at=memory_unit.context.timestamp.created_at,
memory_updated_at=memory_unit.context.timestamp.updated_at,
memory_invalid_at=memory_unit.context.timestamp.invalid_at,
memory_metadata=json.dumps(memory_unit.context.metadata)
if memory_unit.context.metadata
else "{}",
task_issue_title=memory_unit.task.issue_title,
task_issue_body=memory_unit.task.issue_body,
task_issue_comments=memory_unit.task.issue_comments,
task_issue_type=memory_unit.task.issue_type,
task_repository=memory_unit.task.repository,
state_done=memory_unit.state.done,
state_todo=memory_unit.state.todo,
state_open_file=memory_unit.state.open_file,
state_working_dir=memory_unit.state.working_dir,
state_extra_environment=json.dumps(memory_unit.state.extra_environment)
if memory_unit.state.extra_environment
else "{}",
action_name=memory_unit.action.name,
action_description=memory_unit.action.description,
action_target=memory_unit.action.target,
action_tool=memory_unit.action.tool,
result_type=memory_unit.result.type,
result_description=memory_unit.result.description,
result_exit_code=memory_unit.result.exit_code,
)

def to_memory_unit(self) -> MemoryUnit:
"""Convert database model back to MemoryUnit."""
return MemoryUnit(
context=MemoryContext(
memory_id=self.memory_id,
source=self.memory_source,
run_id=self.memory_run_id,
timestamp=MemoryTimestamp(
created_at=self.memory_created_at,
updated_at=self.memory_updated_at,
invalid_at=self.memory_invalid_at,
),
metadata=json.loads(self.memory_metadata)
if self.memory_metadata not in (None, "", "null")
else {},
),
task=Task(
issue_title=self.task_issue_title,
issue_body=self.task_issue_body,
issue_comments=self.task_issue_comments,
issue_type=self.task_issue_type,
repository=self.task_repository,
),
state=State(
done=self.state_done,
todo=self.state_todo,
open_file=self.state_open_file,
working_dir=self.state_working_dir,
extra_environment=json.loads(self.state_extra_environment)
if self.state_extra_environment not in (None, "", "null")
else {},
),
action=Action(
name=self.action_name,
description=self.action_description,
target=self.action_target,
tool=self.action_tool,
),
result=Result(
type=self.result_type
if self.result_type in ["success", "failure", "partial", "unknown"]
else "unknown",
description=self.result_description,
exit_code=self.result_exit_code,
),
)
27 changes: 27 additions & 0 deletions athena/models/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Dict
from pydantic import BaseModel, field_validator


ROLE_ALIASES: Dict[str, str] = {
"ai": "assistant",
"agent": "assistant",
"assistant": "assistant",
"bot": "assistant",
"model": "assistant",
"llm": "assistant",
"sys": "system",
"system_prompt": "system",
"sysmsg": "system",
}


class Message(BaseModel):
content: str
role: str # system | user | assistant | tool | environment | unknown
metadata: dict = {}

@field_validator("role", mode="before")
@classmethod
def normalize_role_field(cls, raw: str) -> str:
r = (raw or "").strip().lower()
return ROLE_ALIASES.get(r, r)
Loading