72 changes: 51 additions & 21 deletions api/core/agent/base_agent_runner.py
@@ -1,8 +1,10 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast

from pydantic import BaseModel
from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentToolEntity
@@ -41,11 +43,28 @@
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)


class AgentThoughtValidation(BaseModel):
"""
Validation model for agent thought data before database persistence.
"""

message_id: str
position: int
thought: str | None = None
tool: str | None = None
tool_input: str | None = None
observation: str | None = None

class Config:
extra = "allow" # Pydantic v1 syntax - should use ConfigDict(extra='forbid')
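
The inline note above points at Pydantic v2's configuration style. A minimal sketch of the equivalent declaration, assuming the stricter extra="forbid" behaviour the comment asks for and keeping the same fields:

from pydantic import BaseModel, ConfigDict


class AgentThoughtValidation(BaseModel):
    """Same validation model, declared with Pydantic v2's ConfigDict instead of class Config."""

    model_config = ConfigDict(extra="forbid")  # reject unexpected fields rather than silently allowing them

    message_id: str
    position: int
    thought: str | None = None
    tool: str | None = None
    tool_input: str | None = None
    observation: str | None = None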


class BaseAgentRunner(AppRunner):
def __init__(
self,
@@ -289,27 +308,28 @@ def create_agent_thought(
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
tool_meta_str="{}",
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal("0.001"),

⚠️ Bug: Swapped answer_unit_price and answer_price_unit defaults

The initial values for answer_unit_price and answer_price_unit are swapped compared to their message_* counterparts:

message_unit_price=Decimal(0),        # ✅ unit price starts at 0
message_price_unit=Decimal("0.001"),  # ✅ price unit = per 1000 tokens
answer_unit_price=Decimal("0.001"),   # ❌ should be Decimal(0)
answer_price_unit=Decimal(0),         # ❌ should be Decimal("0.001")

The semantics of *_unit_price (the cost per token) and *_price_unit (the denomination, e.g., per 1000 tokens) should be consistent between message and answer. The model's server_default for answer_price_unit is sa.text("0.001"), confirming that price_unit should be 0.001, not 0.

These values are later overwritten by actual LLM usage data (llm_usage.completion_unit_price / llm_usage.completion_price_unit), so the impact is limited to the brief window between creation and the update. However, if the update fails or is skipped, the swapped defaults would result in incorrect cost calculations.
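
One way to catch this kind of swap early is a consistency check on a freshly created thought, before any usage data is written back. This is a hypothetical helper, not part of the PR; it only assumes the field names and default values shown in the diff above.

from decimal import Decimal


def assert_pricing_defaults_consistent(thought) -> None:
    # Until real LLM usage is recorded, both sides should carry the same
    # placeholder semantics: unit price 0, price unit 0.001 (per 1,000 tokens).
    assert thought.message_unit_price == Decimal(0)
    assert thought.message_price_unit == Decimal("0.001")
    assert thought.answer_unit_price == Decimal(0)        # fails against the swapped defaults above
    assert thought.answer_price_unit == Decimal("0.001")  # fails against the swapped defaults above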

Suggested change
- answer_unit_price=Decimal("0.001"),
- answer_price_unit=Decimal(0),
+ answer_unit_price=Decimal(0),
+ answer_price_unit=Decimal("0.001"),

answer_price_unit=Decimal(0),
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)

@@ -342,7 +362,8 @@ def save_agent_thought(
raise ValueError("agent thought not found")

if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"

if tool_name:
agent_thought.tool = tool_name
@@ -440,21 +461,30 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)

for tool in tools:
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}

observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)

for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -469,7 +499,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
)
tool_call_response.append(
ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
content=str(tool_inputs.get(tool, agent_thought.observation)),

🚨 Bug: ToolPromptMessage uses tool_inputs instead of tool_responses

The ToolPromptMessage content was changed from tool_responses.get(tool, agent_thought.observation) to str(tool_inputs.get(tool, agent_thought.observation)). This means that when reconstructing conversation history, tool call responses will now contain tool inputs (the arguments sent to the tool) instead of the actual tool outputs/observations.

This corrupts the agent's conversation history, causing the LLM to see tool inputs where it should see tool results. This will lead to incorrect agent behavior in multi-turn conversations that reference prior tool usage.

The original code using tool_responses was correct. The variable tool_responses is parsed from agent_thought.observation (the tool's output), which is the right data to put in a ToolPromptMessage.
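
For reference, a minimal sketch of the intended reconstruction with the fix applied. The build_tool_messages helper and the import path are assumptions for illustration; in the actual code this logic lives inline in organize_agent_history, and each tool_call_id is shared with the matching AssistantPromptMessage tool call rather than generated separately.

import json
import uuid

# Import path assumed for the sketch; base_agent_runner.py already imports ToolPromptMessage.
from core.model_runtime.entities.message_entities import ToolPromptMessage


def build_tool_messages(agent_thought) -> list[ToolPromptMessage]:
    tool_names = (agent_thought.tool or "").split(";")
    observation = agent_thought.observation

    # Parse the recorded observation into a per-tool mapping of tool *outputs*;
    # fall back to the raw observation string when it is empty or not valid JSON.
    try:
        tool_responses = json.loads(observation) if observation else {}
    except Exception:
        tool_responses = dict.fromkeys(tool_names, observation)

    messages: list[ToolPromptMessage] = []
    for tool in tool_names:
        messages.append(
            ToolPromptMessage(
                content=str(tool_responses.get(tool, observation)),  # tool output, not tool input
                name=tool,
                tool_call_id=str(uuid.uuid4()),
            )
        )
    return messages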

Suggested change
- content=str(tool_inputs.get(tool, agent_thought.observation)),
+ content=str(tool_responses.get(tool, agent_thought.observation)),

name=tool,
tool_call_id=tool_call_id,
)
@@ -484,7 +514,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
62 changes: 35 additions & 27 deletions api/models/model.py
@@ -1835,42 +1835,50 @@ class MessageChain(TypeBase):
)


class MessageAgentThought(Base):
class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
sa.Index("message_agent_thought_message_id_idx", "message_id"),
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)

id = mapped_column(StringUUID, default=lambda: str(uuid4()))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
thought = mapped_column(LongText, nullable=True)
tool = mapped_column(LongText, nullable=True)
tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input = mapped_column(LongText, nullable=True)
observation = mapped_column(LongText, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(LongText, nullable=True)
message = mapped_column(LongText, nullable=True)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(LongText, nullable=True)
answer = mapped_column(LongText, nullable=True)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String(255), nullable=True)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
message_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
)

@property
def files(self) -> list[Any]:
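
The switch of MessageAgentThought to TypeBase, with init=False and a default_factory on id, explains the test changes below: callers can no longer pass id explicitly, and a UUID is generated at construction time. A minimal sketch, assuming TypeBase follows SQLAlchemy's MappedAsDataclass pattern (which the default_factory and init=False arguments imply):

from decimal import Decimal
from uuid import uuid4

from models.enums import CreatorUserRole
from models.model import MessageAgentThought

thought = MessageAgentThought(
    message_id=str(uuid4()),
    position=1,
    thought="first step",
    created_by_role=CreatorUserRole.ACCOUNT,
    created_by=str(uuid4()),
)

assert thought.id                                      # generated by the default_factory, not passed in
assert thought.message_price_unit == Decimal("0.001")  # Python-side default mirrors the server_default
# Passing id=... would now raise a TypeError, since the column is declared with init=False.
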
@@ -230,7 +230,6 @@ def _create_test_agent_thoughts(self, db_session_with_containers, message):

# Create first agent thought
thought1 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
@@ -257,7 +256,6 @@ def _create_test_agent_thoughts(self, db_session_with_containers, message):

# Create second agent thought
thought2 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=2,
thought="Based on the analysis, I can provide a response",
@@ -545,7 +543,6 @@ def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_e

# Create agent thought with tool error
thought_with_error = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
@@ -759,7 +756,6 @@ def test_get_agent_logs_with_complex_tool_data(

# Create agent thought with multiple tools
complex_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to use multiple tools to complete this task",
@@ -877,7 +873,6 @@ def test_get_agent_logs_with_files(self, db_session_with_containers, mock_extern

# Create agent thought with files
thought_with_files = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to process some files",
@@ -957,7 +952,6 @@ def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, m

# Create agent thought with empty tool data
empty_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
@@ -999,7 +993,6 @@ def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mo

# Create agent thought with malformed JSON
malformed_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",