From d9d9c306872321ffe5baef02a559a240915976ed Mon Sep 17 00:00:00 2001
From: Zhaoyang-Chu
Date: Wed, 10 Sep 2025 19:01:03 +0800
Subject: [PATCH] Upload data models

---
Review notes: `__table_args__` is now a tuple, MemoryUnitDB fields use
sqlmodel's Field, `created_at` is Optional, and Message.role validation no
longer raises AttributeError on non-string input. The blob ids on the `index`
lines for memory.py and message.py are those of the original upload.

 athena/models/__init__.py |  29 ++++
 athena/models/memory.py   | 280 ++++++++++++++++++++++++++++++++++++++
 athena/models/message.py  |  41 ++++++
 3 files changed, 350 insertions(+)
 create mode 100644 athena/models/__init__.py
 create mode 100644 athena/models/memory.py
 create mode 100644 athena/models/message.py

diff --git a/athena/models/__init__.py b/athena/models/__init__.py
new file mode 100644
index 0000000..4730359
--- /dev/null
+++ b/athena/models/__init__.py
@@ -0,0 +1,29 @@
+from .message import Message
+
+from .memory import (
+    # BaseModel classes for business logic
+    MemoryUnit,
+    MemoryContext,
+    MemoryTimestamp,
+    Task,
+    State,
+    Action,
+    Result,
+    # SQLModel classes for database operations
+    MemoryUnitDB,
+)
+
+__all__ = [
+    # Message exports
+    "Message",
+    # Memory BaseModel exports
+    "MemoryUnit",
+    "MemoryContext",
+    "MemoryTimestamp",
+    "Task",
+    "State",
+    "Action",
+    "Result",
+    # Memory SQLModel exports
+    "MemoryUnitDB",
+]
diff --git a/athena/models/memory.py b/athena/models/memory.py
new file mode 100644
index 0000000..6f6a8fb
--- /dev/null
+++ b/athena/models/memory.py
@@ -0,0 +1,280 @@
+"""Data models for agent memory units and their database representation."""
+
+import json
+import uuid
+from datetime import datetime
+from typing import Any, Dict, Literal, Optional, get_args
+
+from pydantic import BaseModel, Field
+from sqlalchemy import Column, DateTime, Text
+from sqlmodel import Field as SQLField
+from sqlmodel import SQLModel, UniqueConstraint
+
+
+class MemoryTimestamp(BaseModel):
+    """Lifecycle timestamps for a memory unit."""
+
+    # Optional: pydantic rejects an explicit ``created_at=None`` for a plain
+    # ``datetime``, which broke MemoryUnitDB.to_memory_unit() for unset values.
+    created_at: Optional[datetime] = Field(
+        None, description="When the memory unit was first created"
+    )
+    updated_at: Optional[datetime] = Field(
+        None, description="When the memory was last updated/refreshed"
+    )
+    invalid_at: Optional[datetime] = Field(
+        None, description="When the memory was invalidated or expired"
+    )
+
+
+class MemoryContext(BaseModel):
+    """Context metadata for a memory unit."""
+
+    memory_id: str = Field(
+        default_factory=lambda: str(uuid.uuid4()),
+        description="Unique identifier for this memory unit",
+    )
+    source: str = Field(
+        ..., description="Memory source, e.g., agent name, model name, dataset name, or file path"
+    )
+    run_id: str = Field(..., description="Unique id for agent run or trajectory dataset")
+    timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory")
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict, description="Optional extension field for custom metadata"
+    )
+
+
+class Task(BaseModel):
+    """Task information extracted from messages or issues."""
+
+    issue_title: str = Field(..., description="Short title of the issue/task")
+    issue_body: str = Field(..., description="Detailed description of the issue/task")
+    issue_comments: str = Field(..., description="Relevant comments or discussions")
+    issue_type: str = Field(
+        ..., description="Type of issue: bug | feature | documentation | question | other"
+    )
+    repository: str = Field(..., description="Repository where the issue/task belongs")
+
+
+class State(BaseModel):
+    """State information including completed and pending work."""
+
+    done: str = Field(..., description="Summary of what has already been completed")
+    todo: str = Field(..., description="Summary of what remains to be done")
+    open_file: Optional[str] = Field(
+        None,
+        description="Path of the currently opened or edited file, if any (e.g., /sympy__sympy/reproduce_bug.py)",
+    )
+    working_dir: Optional[str] = Field(
+        None, description="Path of the current working directory (e.g., /sympy__sympy)"
+    )
+    extra_environment: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Other SWE-specific runtime state, e.g. active branch, virtualenv, configs, etc.",
+    )
+
+
+class Action(BaseModel):
+    """Action taken by the agent."""
+
+    name: str = Field(
+        ..., description="Action type (read_file | edit_file | run_test | invoke_tool | etc.)"
+    )
+    description: str = Field(..., description="Detailed description of the action")
+    target: str = Field(
+        ...,
+        description="Target of the action (e.g., utils/math.py, tests/test_math.py, pytest, git)",
+    )
+    tool: str = Field(
+        ..., description="Tool used for the action (e.g., pytest, git, search, editor, bash)"
+    )
+
+
+ResultType = Literal["success", "failure", "partial", "unknown"]
+
+
+class Result(BaseModel):
+    """Execution result produced by an action."""
+
+    type: ResultType = Field(
+        "unknown", description="Execution outcome: success | failure | partial | unknown"
+    )
+    description: str = Field("", description="Summary of the result")
+    exit_code: Optional[int] = Field(
+        None, description="Exit code if extractable from logs (optional)"
+    )
+
+
+class MemoryUnit(BaseModel):
+    """
+    Core memory unit capturing one action of agent execution.
+
+    This includes:
+    - The memory context (memory_id, source, run_id, timestamps, metadata)
+    - The task being worked on (issue and repository details)
+    - The current state (what's done, what's todo)
+    - The action taken by the agent
+    - The result of the action
+    """
+
+    context: MemoryContext
+    task: Task
+    state: State
+    action: Action
+    result: Result
+
+
+def _dump_json_object(value: Optional[Dict[str, Any]]) -> str:
+    """Serialize a dict for a Text column; an empty or missing dict becomes "{}"."""
+    return json.dumps(value) if value else "{}"
+
+
+def _load_json_object(raw: Optional[str]) -> Dict[str, Any]:
+    """Parse a JSON Text column; None, empty string and "null" become an empty dict."""
+    if raw in (None, "", "null"):
+        return {}
+    return json.loads(raw)
+
+
+# Database models for persistent storage
+class MemoryUnitDB(SQLModel, table=True):
+    """Database model for persistent storage of memory units.
+
+    The nested MemoryUnit fields are flattened into prefixed columns; the two
+    dict fields (metadata, extra_environment) are stored as JSON strings.
+    """
+
+    __tablename__ = "memory_units"
+    # SQLAlchemy only accepts a tuple, dict or None here, not a bare constraint.
+    __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),)
+
+    # SQLField (sqlmodel.Field) is what carries primary_key / sa_column through
+    # to SQLAlchemy; pydantic's Field does not.
+    id: Optional[int] = SQLField(default=None, primary_key=True)
+
+    # Context information
+    memory_id: str
+    memory_source: str
+    memory_run_id: str
+    memory_created_at: Optional[datetime] = SQLField(
+        default=None, sa_column=Column(DateTime(timezone=True))
+    )
+    memory_updated_at: Optional[datetime] = SQLField(
+        default=None, sa_column=Column(DateTime(timezone=True))
+    )
+    memory_invalid_at: Optional[datetime] = SQLField(
+        default=None, sa_column=Column(DateTime(timezone=True))
+    )
+    memory_metadata: str = SQLField(
+        default="{}", sa_column=Column(Text)
+    )  # JSON string for Dict[str, Any]
+
+    # Task information
+    task_issue_title: str = SQLField(sa_column=Column(Text))
+    task_issue_body: str = SQLField(sa_column=Column(Text))
+    task_issue_comments: str = SQLField(sa_column=Column(Text))
+    task_issue_type: str
+    task_repository: str
+
+    # State information
+    state_done: str = SQLField(sa_column=Column(Text))
+    state_todo: str = SQLField(sa_column=Column(Text))
+    state_open_file: Optional[str] = None
+    state_working_dir: Optional[str] = None
+    state_extra_environment: str = SQLField(
+        default="{}", sa_column=Column(Text)
+    )  # JSON string for Dict[str, str]
+
+    # Action information
+    action_name: str
+    action_description: str = SQLField(sa_column=Column(Text))
+    action_target: str
+    action_tool: str
+
+    # Result information
+    result_type: str
+    result_description: str = SQLField(sa_column=Column(Text))
+    result_exit_code: Optional[int] = None
+
+    @classmethod
+    def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB":
+        """Create a database model from a MemoryUnit.
+
+        Raises:
+            TypeError: If context.metadata holds a value json.dumps cannot serialize.
+        """
+        context = memory_unit.context
+        task = memory_unit.task
+        state = memory_unit.state
+        action = memory_unit.action
+        result = memory_unit.result
+        return cls(
+            memory_id=context.memory_id,
+            memory_source=context.source,
+            memory_run_id=context.run_id,
+            memory_created_at=context.timestamp.created_at,
+            memory_updated_at=context.timestamp.updated_at,
+            memory_invalid_at=context.timestamp.invalid_at,
+            memory_metadata=_dump_json_object(context.metadata),
+            task_issue_title=task.issue_title,
+            task_issue_body=task.issue_body,
+            task_issue_comments=task.issue_comments,
+            task_issue_type=task.issue_type,
+            task_repository=task.repository,
+            state_done=state.done,
+            state_todo=state.todo,
+            state_open_file=state.open_file,
+            state_working_dir=state.working_dir,
+            state_extra_environment=_dump_json_object(state.extra_environment),
+            action_name=action.name,
+            action_description=action.description,
+            action_target=action.target,
+            action_tool=action.tool,
+            result_type=result.type,
+            result_description=result.description,
+            result_exit_code=result.exit_code,
+        )
+
+    def to_memory_unit(self) -> MemoryUnit:
+        """Convert database model back to MemoryUnit."""
+        return MemoryUnit(
+            context=MemoryContext(
+                memory_id=self.memory_id,
+                source=self.memory_source,
+                run_id=self.memory_run_id,
+                timestamp=MemoryTimestamp(
+                    created_at=self.memory_created_at,
+                    updated_at=self.memory_updated_at,
+                    invalid_at=self.memory_invalid_at,
+                ),
+                metadata=_load_json_object(self.memory_metadata),
+            ),
+            task=Task(
+                issue_title=self.task_issue_title,
+                issue_body=self.task_issue_body,
+                issue_comments=self.task_issue_comments,
+                issue_type=self.task_issue_type,
+                repository=self.task_repository,
+            ),
+            state=State(
+                done=self.state_done,
+                todo=self.state_todo,
+                open_file=self.state_open_file,
+                working_dir=self.state_working_dir,
+                extra_environment=_load_json_object(self.state_extra_environment),
+            ),
+            action=Action(
+                name=self.action_name,
+                description=self.action_description,
+                target=self.action_target,
+                tool=self.action_tool,
+            ),
+            result=Result(
+                # A stored value outside ResultType degrades to "unknown".
+                type=self.result_type
+                if self.result_type in get_args(ResultType)
+                else "unknown",
+                description=self.result_description,
+                exit_code=self.result_exit_code,
+            ),
+        )
diff --git a/athena/models/message.py b/athena/models/message.py
new file mode 100644
index 0000000..5c97cd8
--- /dev/null
+++ b/athena/models/message.py
@@ -0,0 +1,41 @@
+"""Chat message model with role-name normalization."""
+
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field, field_validator
+
+# Lower-cased role spellings mapped to their canonical role name.
+ROLE_ALIASES: Dict[str, str] = {
+    "ai": "assistant",
+    "agent": "assistant",
+    "assistant": "assistant",
+    "bot": "assistant",
+    "model": "assistant",
+    "llm": "assistant",
+    "sys": "system",
+    "system_prompt": "system",
+    "sysmsg": "system",
+}
+
+
+class Message(BaseModel):
+    """A single message; ``role`` is normalized through ROLE_ALIASES."""
+
+    content: str
+    role: str  # system | user | assistant | tool | environment | unknown
+    metadata: dict = Field(default_factory=dict)
+
+    @field_validator("role", mode="before")
+    @classmethod
+    def normalize_role_field(cls, raw: Any) -> Any:
+        """Trim and lower-case the role, then map known aliases.
+
+        Falsy input becomes "". A non-string value is returned untouched so
+        that pydantic's own ``str`` check rejects it with a ValidationError
+        instead of this validator raising AttributeError.
+        """
+        value = raw or ""
+        if not isinstance(value, str):
+            return value
+        key = value.strip().lower()
+        return ROLE_ALIASES.get(key, key)