From 9ddb376f318d1cec0c6ee202966169fc8d639690 Mon Sep 17 00:00:00 2001 From: xinquiry Date: Wed, 7 Jan 2026 21:59:39 +0800 Subject: [PATCH 1/4] feat: add new schemas --- service/app/models/agent.py | 7 + service/app/schemas/agent_events.py | 237 ++++++++++++++ service/app/schemas/chat_event_types.py | 107 +++++++ service/app/schemas/chat_events.py | 52 +++ service/app/schemas/graph_config.py | 407 ++++++++++++++++++++++++ service/app/tasks/chat.py | 16 +- 6 files changed, 818 insertions(+), 8 deletions(-) create mode 100644 service/app/schemas/agent_events.py create mode 100644 service/app/schemas/graph_config.py diff --git a/service/app/models/agent.py b/service/app/models/agent.py index 199cbbc8..f70265a7 100644 --- a/service/app/models/agent.py +++ b/service/app/models/agent.py @@ -16,6 +16,13 @@ class AgentScope(StrEnum): USER = "user" +class AgentType(StrEnum): + """Type of agent determining execution strategy.""" + + GRAPH = "graph" # JSON-configured graph agent (user-customizable) + SYSTEM = "system" # Python-coded system agent (e.g., react, deep_research) + + class AgentBase(SQLModel): scope: AgentScope = Field( sa_column=sa.Column( diff --git a/service/app/schemas/agent_events.py b/service/app/schemas/agent_events.py new file mode 100644 index 00000000..b4b95648 --- /dev/null +++ b/service/app/schemas/agent_events.py @@ -0,0 +1,237 @@ +""" +Agent Event Data Types - Typed dictionaries for agent execution events. + +This module defines the data structures for complex agent execution events, +providing flat context metadata for tracking nested agent execution. +""" + +from __future__ import annotations + +from typing import Any, TypedDict + +from typing_extensions import NotRequired + + +class AgentExecutionContext(TypedDict): + """ + Flat context metadata included with all agent events. + + This context allows the frontend to: + - Track which agent is executing + - Understand execution depth for nested agents + - Visualize the execution path + - Measure timing + """ + + # Agent identification + agent_id: str # UUID of the executing agent + agent_name: str # Human-readable name + agent_type: str # "react", "graph", "system" + + # Execution tracking (flat, not hierarchical) + execution_id: str # Unique ID for this execution run + parent_execution_id: NotRequired[str] # Present if this is a subagent + depth: int # 0 for root agent, 1 for first subagent, etc. 
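    # A minimal sketch of a freshly-created root-agent context (all values
    # hypothetical), assuming the dict is built once when execution starts
    # and the remaining fields declared below are populated the same way:
    #
    #   context: AgentExecutionContext = {
    #       "agent_id": "3f2a0c1e-...",           # hypothetical UUID
    #       "agent_name": "deep_research",
    #       "agent_type": "graph",
    #       "execution_id": "exec-001",           # hypothetical run ID
    #       "depth": 0,                           # root agent, no parent
    #       "execution_path": ["deep_research"],
    #       "started_at": 1767800379.0,
    #   }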
+ execution_path: list[str] # Path of agent names: ["root", "deep_research", "web_search"] + + # Current position in graph + current_node: NotRequired[str] # Current node ID + current_phase: NotRequired[str] # Current phase name (if applicable) + + # Timing + started_at: float # Unix timestamp when execution started + elapsed_ms: NotRequired[int] # Milliseconds since started_at + + +# === Agent Lifecycle Events === + + +class AgentStartData(TypedDict): + """Data for AGENT_START event.""" + + context: AgentExecutionContext + total_nodes: NotRequired[int] # Total nodes in graph + estimated_duration_ms: NotRequired[int] + + +class AgentEndData(TypedDict): + """Data for AGENT_END event.""" + + context: AgentExecutionContext + status: str # "completed", "failed", "cancelled" + duration_ms: int + output_summary: NotRequired[str] + + +class AgentErrorData(TypedDict): + """Data for AGENT_ERROR event.""" + + context: AgentExecutionContext + error_type: str # Error class name + error_message: str + recoverable: bool + node_id: NotRequired[str] # Node where error occurred + + +# === Phase Events === + + +class PhaseStartData(TypedDict): + """Data for PHASE_START event.""" + + phase_id: str + phase_name: str + description: NotRequired[str] + expected_duration_ms: NotRequired[int] + context: AgentExecutionContext + + +class PhaseEndData(TypedDict): + """Data for PHASE_END event.""" + + phase_id: str + phase_name: str + status: str # "completed", "failed", "skipped" + duration_ms: int + output_summary: NotRequired[str] + context: AgentExecutionContext + + +# === Node Events === + + +class NodeStartData(TypedDict): + """Data for NODE_START event.""" + + node_id: str + node_name: str + node_type: str # "llm", "tool", "router", etc. + input_summary: NotRequired[str] + context: AgentExecutionContext + + +class NodeEndData(TypedDict): + """Data for NODE_END event.""" + + node_id: str + node_name: str + node_type: str + status: str # "completed", "failed", "skipped" + duration_ms: int + output_summary: NotRequired[str] + context: AgentExecutionContext + + +# === Subagent Events === + + +class SubagentStartData(TypedDict): + """Data for SUBAGENT_START event.""" + + subagent_id: str + subagent_name: str + subagent_type: str # Agent type of the subagent + input_summary: NotRequired[str] + context: AgentExecutionContext + + +class SubagentEndData(TypedDict): + """Data for SUBAGENT_END event.""" + + subagent_id: str + subagent_name: str + status: str + duration_ms: int + output_summary: NotRequired[str] + context: AgentExecutionContext + + +# === Progress Events === + + +class ProgressUpdateData(TypedDict): + """Data for PROGRESS_UPDATE event.""" + + progress_percent: int # 0-100 + message: str # Human-readable progress message + details: NotRequired[dict[str, Any]] # Additional structured data + context: AgentExecutionContext + + +# === Iteration Events === + + +class IterationStartData(TypedDict): + """Data for ITERATION_START event.""" + + iteration_number: int # 1-indexed + max_iterations: int + reason: NotRequired[str] # Why iteration is needed + context: AgentExecutionContext + + +class IterationEndData(TypedDict): + """Data for ITERATION_END event.""" + + iteration_number: int + will_continue: bool # Whether another iteration will follow + reason: NotRequired[str] # Why continuing or stopping + context: AgentExecutionContext + + +# === State Events === + + +class StateUpdateData(TypedDict): + """ + Data for STATE_UPDATE event. + + Only includes non-sensitive state changes that are safe to display. 
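
    Example (a hypothetical payload; the key names are illustrative):

        {
            "updated_keys": ["report_outline"],
            "summary": {"report_outline": "Drafted a 5-section outline"},
            "context": {...},  # the AgentExecutionContext defined above
        }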
+ """ + + updated_keys: list[str] # State keys that changed + summary: dict[str, str] # Key -> human-readable summary of new value + context: AgentExecutionContext + + +# === Human-in-the-Loop Events === + + +class HumanInputRequiredData(TypedDict): + """Data for HUMAN_INPUT_REQUIRED event.""" + + prompt: str # Message to display to user + input_type: str # "text", "choice", "confirm", "form" + choices: NotRequired[list[str]] # For "choice" type + form_schema: NotRequired[dict[str, Any]] # JSON Schema for "form" type + timeout_seconds: NotRequired[int] + context: AgentExecutionContext + + +class HumanInputReceivedData(TypedDict): + """Data for HUMAN_INPUT_RECEIVED event.""" + + input_value: Any # The user's input + input_type: str + context: AgentExecutionContext + + +# Export all types +__all__ = [ + "AgentExecutionContext", + "AgentStartData", + "AgentEndData", + "AgentErrorData", + "PhaseStartData", + "PhaseEndData", + "NodeStartData", + "NodeEndData", + "SubagentStartData", + "SubagentEndData", + "ProgressUpdateData", + "IterationStartData", + "IterationEndData", + "StateUpdateData", + "HumanInputRequiredData", + "HumanInputReceivedData", +] diff --git a/service/app/schemas/chat_event_types.py b/service/app/schemas/chat_event_types.py index 2023d889..2cf4443c 100644 --- a/service/app/schemas/chat_event_types.py +++ b/service/app/schemas/chat_event_types.py @@ -15,6 +15,7 @@ """ from typing import Any, TypedDict + from typing_extensions import Literal, NotRequired from app.schemas.chat_events import ChatEventType @@ -169,6 +170,61 @@ class ThinkingEndData(TypedDict): id: str +class AgentContextData(TypedDict): + """Common context shared by agent and node events.""" + + agent_id: str + agent_name: str + agent_type: str + execution_id: str + depth: int + execution_path: list[str] + started_at: int + current_node: NotRequired[str] + + +class AgentStartData(TypedDict): + """Data payload for AGENT_START event.""" + + context: AgentContextData + + +class AgentEndData(TypedDict): + """Data payload for AGENT_END event.""" + + context: AgentContextData + status: str + duration_ms: int + + +class NodeStartData(TypedDict): + """Data payload for NODE_START event.""" + + node_id: str + node_name: str + node_type: str + context: AgentContextData + + +class NodeEndData(TypedDict): + """Data payload for NODE_END event.""" + + node_id: str + node_name: str + node_type: str + status: str + duration_ms: int + context: AgentContextData + + +class ProgressUpdateData(TypedDict): + """Data payload for PROGRESS_UPDATE event.""" + + progress_percent: int + message: str + context: AgentContextData + + # ============================================================================= # Full Event Structures (type + data) # ============================================================================= @@ -293,6 +349,41 @@ class ThinkingEndEvent(TypedDict): data: ThinkingEndData +class AgentStartEvent(TypedDict): + """Full event structure for agent start.""" + + type: Literal[ChatEventType.AGENT_START] + data: AgentStartData + + +class AgentEndEvent(TypedDict): + """Full event structure for agent end.""" + + type: Literal[ChatEventType.AGENT_END] + data: AgentEndData + + +class NodeStartEvent(TypedDict): + """Full event structure for node start.""" + + type: Literal[ChatEventType.NODE_START] + data: NodeStartData + + +class NodeEndEvent(TypedDict): + """Full event structure for node end.""" + + type: Literal[ChatEventType.NODE_END] + data: NodeEndData + + +class ProgressUpdateEvent(TypedDict): + """Full event structure 
for progress updates.""" + + type: Literal[ChatEventType.PROGRESS_UPDATE] + data: ProgressUpdateData + + # ============================================================================= # Union type for generic event handling # ============================================================================= @@ -316,6 +407,11 @@ class ThinkingEndEvent(TypedDict): | ThinkingStartEvent | ThinkingChunkEvent | ThinkingEndEvent + | AgentStartEvent + | AgentEndEvent + | NodeStartEvent + | NodeEndEvent + | ProgressUpdateEvent ) @@ -358,6 +454,17 @@ class ThinkingEndEvent(TypedDict): "ThinkingStartEvent", "ThinkingChunkEvent", "ThinkingEndEvent", + "AgentContextData", + "AgentStartData", + "AgentEndData", + "NodeStartData", + "NodeEndData", + "ProgressUpdateData", + "AgentStartEvent", + "AgentEndEvent", + "NodeStartEvent", + "NodeEndEvent", + "ProgressUpdateEvent", # Union "StreamingEvent", ] diff --git a/service/app/schemas/chat_events.py b/service/app/schemas/chat_events.py index 5207ee45..5db68be0 100644 --- a/service/app/schemas/chat_events.py +++ b/service/app/schemas/chat_events.py @@ -52,6 +52,39 @@ class ChatEventType(StrEnum): THINKING_CHUNK = "thinking_chunk" THINKING_END = "thinking_end" + # === Agent Execution Events (for complex graph-based agents) === + + # Agent lifecycle + AGENT_START = "agent_start" + AGENT_END = "agent_end" + AGENT_ERROR = "agent_error" + + # Phase/workflow stage execution + PHASE_START = "phase_start" + PHASE_END = "phase_end" + + # Individual node execution + NODE_START = "node_start" + NODE_END = "node_end" + + # Subagent execution (nested agents) + SUBAGENT_START = "subagent_start" + SUBAGENT_END = "subagent_end" + + # Progress updates + PROGRESS_UPDATE = "progress_update" + + # State changes (non-sensitive updates) + STATE_UPDATE = "state_update" + + # Iteration events (for loops in graph execution) + ITERATION_START = "iteration_start" + ITERATION_END = "iteration_end" + + # Human-in-the-loop events + HUMAN_INPUT_REQUIRED = "human_input_required" + HUMAN_INPUT_RECEIVED = "human_input_received" + class ChatClientEventType(StrEnum): """Client -> Server event types (messages coming from the frontend).""" @@ -94,6 +127,24 @@ class ProcessingStatus(StrEnum): {ChatEventType.TOOL_CALL_REQUEST, ChatEventType.TOOL_CALL_RESPONSE} ) +SERVER_AGENT_EVENTS: FrozenSet[ChatEventType] = frozenset( + { + ChatEventType.AGENT_START, + ChatEventType.AGENT_END, + ChatEventType.AGENT_ERROR, + ChatEventType.PHASE_START, + ChatEventType.PHASE_END, + ChatEventType.NODE_START, + ChatEventType.NODE_END, + ChatEventType.SUBAGENT_START, + ChatEventType.SUBAGENT_END, + ChatEventType.PROGRESS_UPDATE, + ChatEventType.STATE_UPDATE, + ChatEventType.ITERATION_START, + ChatEventType.ITERATION_END, + } +) + __all__ = [ "ChatEventType", "ChatClientEventType", @@ -101,4 +152,5 @@ class ProcessingStatus(StrEnum): "ProcessingStatus", "SERVER_STREAMING_EVENTS", "SERVER_TOOL_EVENTS", + "SERVER_AGENT_EVENTS", ] diff --git a/service/app/schemas/graph_config.py b/service/app/schemas/graph_config.py new file mode 100644 index 00000000..76f39add --- /dev/null +++ b/service/app/schemas/graph_config.py @@ -0,0 +1,407 @@ +""" +Graph Configuration Schema for JSON-configurable agents. + +This module defines the complete JSON schema for graph-based agents, +allowing users to configure agent workflows via JSON configuration. 
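
Example (a minimal, hypothetical two-node config; the exact JSON shape is
defined by the schemas below, and all values here are illustrative):

    {
        "version": "1.0",
        "entry_point": "draft",
        "nodes": [
            {
                "id": "draft",
                "name": "Draft answer",
                "type": "llm",
                "llm_config": {"prompt_template": "Answer: {{ state.question }}"}
            }
        ],
        "edges": [
            {"from_node": "START", "to_node": "draft"},
            {"from_node": "draft", "to_node": "END"}
        ]
    }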
+""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class NodeType(StrEnum): + """Types of nodes in a graph agent.""" + + LLM = "llm" # LLM reasoning node + TOOL = "tool" # Tool execution node + ROUTER = "router" # Conditional routing node + SUBAGENT = "subagent" # Nested agent invocation + TRANSFORM = "transform" # Data transformation + PARALLEL = "parallel" # Parallel execution of multiple branches + HUMAN = "human" # Human-in-the-loop checkpoint + + +class ConditionOperator(StrEnum): + """Operators for edge conditions.""" + + EQUALS = "eq" + NOT_EQUALS = "neq" + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL = "gte" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL = "lte" + IN = "in" + NOT_IN = "not_in" + TRUTHY = "truthy" + FALSY = "falsy" + MATCHES = "matches" # Regex match + + +class ReducerType(StrEnum): + """Reducer types for state fields with multiple updates.""" + + REPLACE = "replace" # Replace value (default) + APPEND = "append" # Append to list + MERGE = "merge" # Merge dictionaries + ADD = "add" # Add numbers + MESSAGES = "messages" # Special reducer for message lists + + +# --- State Schema Definitions --- + + +class StateFieldSchema(BaseModel): + """Schema for a single state field.""" + + type: str = Field(description="Field type: 'string', 'int', 'float', 'bool', 'list', 'dict', 'messages', 'any'") + description: str | None = Field(default=None, description="Human-readable field description") + default: Any = Field(default=None, description="Default value for the field") + reducer: ReducerType | None = Field( + default=None, + description="How to combine multiple updates to this field", + ) + required: bool = Field(default=False, description="Whether this field is required") + + +class GraphStateSchema(BaseModel): + """Complete state schema for a graph agent.""" + + fields: dict[str, StateFieldSchema] = Field( + default_factory=dict, + description="Field definitions. 'messages' and 'execution_context' are always added automatically.", + ) + + +# --- Node Configuration Definitions --- + + +class LLMNodeConfig(BaseModel): + """Configuration for LLM reasoning nodes.""" + + prompt_template: str = Field(description="Jinja2 template for the prompt. Access state via {{ state.field_name }}") + output_key: str = Field(default="response", description="State key to store the LLM response") + model_override: str | None = Field(default=None, description="Override the agent's default model") + temperature_override: float | None = Field(default=None, ge=0.0, le=2.0, description="Override temperature") + max_tokens: int | None = Field(default=None, description="Maximum tokens in response") + tools_enabled: bool = Field(default=True, description="Whether to bind tools to this LLM call") + tool_filter: list[str] | None = Field(default=None, description="Specific tool names to enable (None = all)") + max_iterations: int = Field(default=10, ge=1, description="Maximum iterations for ReAct-style tool loops") + stop_sequences: list[str] | None = Field(default=None, description="Stop sequences for generation") + + +class ToolNodeConfig(BaseModel): + """Configuration for tool execution nodes.""" + + tool_name: str = Field(description="Name of the tool (MCP tool name or built-in)") + arguments_template: dict[str, str] = Field( + default_factory=dict, + description="Jinja2 templates for tool arguments. 
Keys are argument names.", + ) + output_key: str = Field(default="tool_result", description="State key to store the tool result") + timeout_seconds: int = Field(default=60, ge=1, le=600, description="Tool execution timeout") + retry_count: int = Field(default=0, ge=0, le=3, description="Number of retries on failure") + + +class EdgeCondition(BaseModel): + """Condition for conditional routing.""" + + state_key: str = Field(description="State key to evaluate") + operator: ConditionOperator = Field(description="Comparison operator") + value: Any = Field(default=None, description="Value to compare against (not needed for truthy/falsy)") + target: str = Field(description="Target node name if condition matches") + + +class RouterNodeConfig(BaseModel): + """Configuration for routing/branching decisions.""" + + strategy: Literal["condition", "llm", "state_check"] = Field( + default="condition", + description="Routing strategy: 'condition' (rule-based), 'llm' (AI decides), 'state_check'", + ) + conditions: list[EdgeCondition] = Field( + default_factory=list, + description="List of conditions to evaluate (for 'condition' strategy)", + ) + llm_prompt: str | None = Field( + default=None, + description="Prompt for LLM to decide route (for 'llm' strategy). Should output route name.", + ) + routes: list[str] = Field( + default_factory=list, + description="Valid route names the router can choose from", + ) + default_route: str = Field(default="END", description="Fallback route if no conditions match") + + +class SubagentNodeConfig(BaseModel): + """Configuration for invoking nested agents.""" + + agent_ref: str = Field( + description="Agent reference: UUID for user agents, key for system agents (e.g., 'deep_research')" + ) + input_mapping: dict[str, str] = Field( + default_factory=dict, + description="Map parent state keys to child input. Values are Jinja2 expressions.", + ) + output_mapping: dict[str, str] = Field( + default_factory=dict, + description="Map child output keys to parent state keys", + ) + inherit_context: bool = Field(default=True, description="Whether to pass execution context to the subagent") + inherit_tools: bool = Field(default=True, description="Whether subagent inherits parent's tools") + timeout_seconds: int = Field(default=300, ge=1, le=3600, description="Subagent execution timeout") + + +class ParallelNodeConfig(BaseModel): + """Configuration for parallel execution of multiple branches.""" + + branches: list[str] = Field(description="Node names to execute in parallel") + join_strategy: Literal["wait_all", "wait_any", "wait_n"] = Field( + default="wait_all", + description="How to wait for branches: all, any, or N branches", + ) + wait_count: int | None = Field(default=None, description="Number of branches to wait for (for 'wait_n' strategy)") + merge_strategy: Literal["merge_dicts", "list", "first", "custom"] = Field( + default="merge_dicts", + description="How to merge results from parallel branches", + ) + merge_key: str = Field(default="parallel_results", description="State key to store merged results") + timeout_seconds: int = Field(default=120, ge=1, le=600, description="Timeout for parallel execution") + + +class TransformNodeConfig(BaseModel): + """Configuration for data transformation.""" + + expression: str | None = Field( + default=None, + description="Python expression evaluated in restricted context. 
Use state['key'] to access values.", + ) + template: str | None = Field(default=None, description="Jinja2 template for complex transformations") + output_key: str = Field(description="State key to store the transformation result") + input_keys: list[str] = Field(default_factory=list, description="State keys required for this transformation") + + +class HumanNodeConfig(BaseModel): + """Configuration for human-in-the-loop checkpoints.""" + + prompt_template: str = Field(description="Message to display to the human") + input_type: Literal["text", "choice", "confirm", "form"] = Field( + default="text", description="Type of human input expected" + ) + choices: list[str] | None = Field(default=None, description="Available choices (for 'choice' input type)") + form_schema: dict[str, Any] | None = Field( + default=None, description="JSON Schema for form input (for 'form' input type)" + ) + output_key: str = Field(default="human_input", description="State key to store human response") + timeout_seconds: int | None = Field(default=None, description="Timeout for human response (None = no timeout)") + + +# --- Graph Node Definition --- + + +class GraphNodeConfig(BaseModel): + """Complete configuration for a single graph node.""" + + id: str = Field(description="Unique identifier within the graph") + name: str = Field(description="Human-readable display name") + type: NodeType = Field(description="Node type determining execution behavior") + description: str | None = Field(default=None, description="Description of what this node does") + + # Type-specific configurations (exactly one should be set based on type) + llm_config: LLMNodeConfig | None = None + tool_config: ToolNodeConfig | None = None + router_config: RouterNodeConfig | None = None + subagent_config: SubagentNodeConfig | None = None + parallel_config: ParallelNodeConfig | None = None + transform_config: TransformNodeConfig | None = None + human_config: HumanNodeConfig | None = None + + # UI positioning for visual editor + position: dict[str, float] | None = Field( + default=None, description="Position for visual editor: {'x': float, 'y': float}" + ) + + # Error handling + on_error: Literal["raise", "continue", "retry", "fallback"] = Field( + default="raise", description="Error handling strategy" + ) + retry_count: int = Field(default=0, ge=0, le=5, description="Retries before failing") + fallback_node: str | None = Field(default=None, description="Node to execute on error (for 'fallback' strategy)") + + # Metadata + tags: list[str] = Field(default_factory=list, description="Tags for categorization") + + +# --- Graph Edge Definition --- + + +class GraphEdgeConfig(BaseModel): + """Configuration for an edge between nodes.""" + + from_node: str = Field(description="Source node ID (use 'START' for entry point)") + to_node: str = Field(description="Target node ID (use 'END' for exit point)") + condition: EdgeCondition | None = Field( + default=None, description="Optional condition for this edge (None = unconditional)" + ) + label: str | None = Field(default=None, description="Label for UI display") + priority: int = Field(default=0, description="Priority when multiple edges match (higher = checked first)") + + +# --- Complete Graph Configuration --- + + +class GraphConfig(BaseModel): + """ + Complete graph configuration stored in agent.graph_config. + + This schema defines everything needed to build and execute a graph-based agent. 
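
    A sketch of loading and checking a stored config (assumes Pydantic v2's
    ``model_validate`` and the ``validate_graph_config`` helper defined below):

        config = GraphConfig.model_validate(agent.graph_config)
        errors = validate_graph_config(config)
        if errors:
            raise ValueError(f"Invalid graph config: {errors}")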
+ """ + + version: str = Field(default="1.0", description="Schema version for compatibility") + + # State definition + state_schema: GraphStateSchema = Field( + default_factory=GraphStateSchema, + description="State schema defining fields passed between nodes", + ) + + # Graph structure + nodes: list[GraphNodeConfig] = Field(description="All nodes in the graph") + edges: list[GraphEdgeConfig] = Field(description="Connections between nodes") + + # Entry and exit + entry_point: str = Field(description="Node ID to start execution (first node after START)") + exit_points: list[str] = Field( + default_factory=lambda: ["END"], + description="Node IDs that terminate execution", + ) + + # Reusable prompt templates + prompt_templates: dict[str, str] = Field( + default_factory=dict, + description="Named prompt templates that can be referenced via {{ prompt_templates.name }}", + ) + + # Component references for importing reusable components + imported_components: list[str] = Field( + default_factory=list, + description="Component registry keys to import (e.g., 'system:deep_research:query_analyzer')", + ) + + # Metadata + metadata: dict[str, Any] = Field( + default_factory=dict, + description="Additional metadata (author, description, etc.)", + ) + + # Execution settings + max_execution_time_seconds: int = Field(default=300, ge=1, le=3600, description="Maximum total execution time") + enable_checkpoints: bool = Field(default=True, description="Whether to save checkpoints for resumption") + + +# --- Helper Functions --- + + +def validate_graph_config(config: GraphConfig) -> list[str]: + """ + Validate a graph configuration for structural correctness. + + Returns a list of validation errors (empty if valid). + """ + errors: list[str] = [] + + # Check that entry_point exists + node_ids = {node.id for node in config.nodes} + if config.entry_point not in node_ids: + errors.append(f"Entry point '{config.entry_point}' not found in nodes") + + # Check that all edge references are valid + valid_refs = node_ids | {"START", "END"} + for edge in config.edges: + if edge.from_node not in valid_refs: + errors.append(f"Edge from_node '{edge.from_node}' not found") + if edge.to_node not in valid_refs: + errors.append(f"Edge to_node '{edge.to_node}' not found") + + # Check that each node has the correct config for its type + for node in config.nodes: + match node.type: + case NodeType.LLM: + if not node.llm_config: + errors.append(f"Node '{node.id}' is type LLM but missing llm_config") + case NodeType.TOOL: + if not node.tool_config: + errors.append(f"Node '{node.id}' is type TOOL but missing tool_config") + case NodeType.ROUTER: + if not node.router_config: + errors.append(f"Node '{node.id}' is type ROUTER but missing router_config") + case NodeType.SUBAGENT: + if not node.subagent_config: + errors.append(f"Node '{node.id}' is type SUBAGENT but missing subagent_config") + case NodeType.PARALLEL: + if not node.parallel_config: + errors.append(f"Node '{node.id}' is type PARALLEL but missing parallel_config") + case NodeType.TRANSFORM: + if not node.transform_config: + errors.append(f"Node '{node.id}' is type TRANSFORM but missing transform_config") + case NodeType.HUMAN: + if not node.human_config: + errors.append(f"Node '{node.id}' is type HUMAN but missing human_config") + + # Check for unreachable nodes + reachable = {"START"} + changed = True + while changed: + changed = False + for edge in config.edges: + if edge.from_node in reachable and edge.to_node not in reachable: + reachable.add(edge.to_node) + changed = 
True + + unreachable = node_ids - reachable + if unreachable: + errors.append(f"Unreachable nodes: {unreachable}") + + # Check for nodes that don't lead to END + leads_to_end: set[str] = {"END"} + changed = True + while changed: + changed = False + for edge in config.edges: + if edge.to_node in leads_to_end and edge.from_node not in leads_to_end: + leads_to_end.add(edge.from_node) + changed = True + + dead_ends = node_ids - leads_to_end + if dead_ends: + errors.append(f"Nodes that don't lead to END: {dead_ends}") + + return errors + + +# Export commonly used types +__all__ = [ + "NodeType", + "ConditionOperator", + "ReducerType", + "StateFieldSchema", + "GraphStateSchema", + "LLMNodeConfig", + "ToolNodeConfig", + "RouterNodeConfig", + "SubagentNodeConfig", + "ParallelNodeConfig", + "TransformNodeConfig", + "HumanNodeConfig", + "GraphNodeConfig", + "EdgeCondition", + "GraphEdgeConfig", + "GraphConfig", + "validate_graph_config", +] diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index 63d2d70f..965a87e2 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from typing import Any, Dict, List +from typing import Any from uuid import UUID import redis.asyncio as redis @@ -55,10 +55,10 @@ def extract_content_text(content: Any) -> str: if isinstance(content, str): return content if isinstance(content, list): - text_parts: List[str] = [] + text_parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": - text_parts.append(str(item.get("text", ""))) # pyright: ignore + text_parts.append(str(item.get("text", ""))) # pyright: ignore[reportUnknownArgumentType] return "".join(text_parts) return str(content) @@ -70,7 +70,7 @@ def process_chat_message( user_id_str: str, auth_provider: str, message_text: str, - context: Dict[str, Any] | None, + context: dict[str, Any] | None, pre_deducted_amount: float, access_token: str | None = None, ) -> None: @@ -123,7 +123,7 @@ async def _process_chat_message_async( user_id_str: str, auth_provider: str, message_text: str, - context: Dict[str, Any] | None, + context: dict[str, Any] | None, pre_deducted_amount: float, access_token: str | None, ) -> None: @@ -166,7 +166,7 @@ async def _process_chat_message_async( ai_message_obj: Message | None = None full_content = "" full_thinking_content = "" # Track thinking content for persistence - citations_data: List[CitationData] = [] + citations_data: list[CitationData] = [] generated_files_count = 0 input_tokens: int = 0 @@ -282,7 +282,7 @@ async def _process_chat_message_async( elif event_type == ChatEventType.SEARCH_CITATIONS: citations = stream_event["data"].get("citations", []) if citations: - citations_data.extend(citations) # type: ignore + citations_data.extend(citations) await publisher.publish(json.dumps(stream_event)) elif stream_event["type"] == ChatEventType.GENERATED_FILES: @@ -343,7 +343,7 @@ async def _process_chat_message_async( if citations_data: try: citation_repo = CitationRepository(db) - citation_creates: List[CitationCreate] = [] + citation_creates: list[CitationCreate] = [] for citation in citations_data: citation_create = CitationCreate( message_id=ai_message_obj.id, From fd57adc2eb4088179e7a0851c4b801404a17bce9 Mon Sep 17 00:00:00 2001 From: xinquiry Date: Wed, 7 Jan 2026 23:54:01 +0800 Subject: [PATCH 2/4] feat: implement the model and schemas --- service/app/models/__init__.py | 4 - service/app/models/agent.py | 17 +- service/app/models/graph.py | 159 ----- 
service/app/models/message.py | 7 +- service/app/repos/graph.py | 555 ------------------ service/app/schemas/agent.py | 11 - ...gent_events.py => agent_event_payloads.py} | 0 service/app/schemas/chat_event_payloads.py | 554 +++++++++++++++++ service/app/schemas/chat_event_types.py | 538 ++++------------- service/app/schemas/chat_events.py | 156 ----- .../versions/a05b7d38a4a0_add_graph_config.py | 33 ++ .../a7b3c8e51f92_drop_legacy_graph_tables.py | 110 ++++ .../f0d7b93430e1_add_agent_metadata.py | 33 ++ service/pyproject.toml | 13 +- service/tests/integration/test_integration.py | 2 +- .../unit/handler/mcp/test_knowledge_limits.py | 2 +- .../unit/test_core/test_thinking_events.py | 10 +- service/uv.lock | 71 +++ 18 files changed, 942 insertions(+), 1333 deletions(-) delete mode 100644 service/app/models/graph.py delete mode 100644 service/app/repos/graph.py delete mode 100644 service/app/schemas/agent.py rename service/app/schemas/{agent_events.py => agent_event_payloads.py} (100%) create mode 100644 service/app/schemas/chat_event_payloads.py delete mode 100644 service/app/schemas/chat_events.py create mode 100644 service/migrations/versions/a05b7d38a4a0_add_graph_config.py create mode 100644 service/migrations/versions/a7b3c8e51f92_drop_legacy_graph_tables.py create mode 100644 service/migrations/versions/f0d7b93430e1_add_agent_metadata.py diff --git a/service/app/models/__init__.py b/service/app/models/__init__.py index 5bfb1689..1a50ce2f 100644 --- a/service/app/models/__init__.py +++ b/service/app/models/__init__.py @@ -16,7 +16,6 @@ from .file import File, FileCreate, FileRead, FileReadWithUrl, FileUpdate from .file_knowledge_set_link import FileKnowledgeSetLink, FileKnowledgeSetLinkCreate, FileKnowledgeSetLinkRead from .folder import Folder, FolderCreate, FolderRead, FolderUpdate -from .graph import GraphAgent, GraphEdge, GraphNode from .knowledge_set import ( KnowledgeSet, KnowledgeSetCreate, @@ -93,9 +92,6 @@ "ToolVersion", "ToolFunction", "Topic", - "GraphAgent", - "GraphNode", - "GraphEdge", "TopicRead", "TopicReadWithMessages", "SmitheryServersCache", diff --git a/service/app/models/agent.py b/service/app/models/agent.py index f70265a7..d7aadd7f 100644 --- a/service/app/models/agent.py +++ b/service/app/models/agent.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from enum import StrEnum -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 import sqlalchemy as sa @@ -16,13 +16,6 @@ class AgentScope(StrEnum): USER = "user" -class AgentType(StrEnum): - """Type of agent determining execution strategy.""" - - GRAPH = "graph" # JSON-configured graph agent (user-customizable) - SYSTEM = "system" # Python-coded system agent (e.g., react, deep_research) - - class AgentBase(SQLModel): scope: AgentScope = Field( sa_column=sa.Column( @@ -51,6 +44,10 @@ class AgentBase(SQLModel): default=None, description="Version of the marketplace listing this agent was forked from", ) + # JSON configuration for graph-based agents + # If None or empty, fallback to the default react system agent + # Can include metadata.system_agent_key to use a specific system agent as base + graph_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) class Agent(AgentBase, table=True): @@ -78,6 +75,7 @@ class AgentCreate(SQLModel): provider_id: UUID | None = Field(default=None, index=True) knowledge_set_id: UUID | None = Field(default=None) mcp_server_ids: list[UUID] = [] + graph_config: dict[str, Any] | None = None class 
AgentRead(AgentBase): @@ -86,7 +84,7 @@ class AgentRead(AgentBase): class AgentReadWithDetails(AgentRead): - mcp_servers: List["McpServer"] = [] + mcp_servers: list["McpServer"] = [] class AgentUpdate(SQLModel): @@ -101,3 +99,4 @@ class AgentUpdate(SQLModel): provider_id: UUID | None = None knowledge_set_id: UUID | None = None mcp_server_ids: list[UUID] | None = None + graph_config: dict[str, Any] | None = None diff --git a/service/app/models/graph.py b/service/app/models/graph.py deleted file mode 100644 index d97d596d..00000000 --- a/service/app/models/graph.py +++ /dev/null @@ -1,159 +0,0 @@ -from datetime import datetime, timezone -from typing import Any -from uuid import UUID, uuid4 - -from sqlalchemy import TIMESTAMP -from sqlmodel import JSON, Column, Field, SQLModel - - -class GraphAgentBase(SQLModel): - name: str = Field(max_length=100) - description: str | None = Field(default=None, max_length=500) - state_schema: dict[str, Any] = Field(sa_column=Column(JSON)) - is_active: bool = Field(default=True) - parent_agent_id: UUID | None = Field(default=None, index=True) - user_id: str = Field(index=True) - is_published: bool = Field(default=False, index=True) - is_official: bool = Field(default=False, index=True) - - -class GraphAgent(GraphAgentBase, table=True): - id: UUID = Field(default_factory=uuid4, primary_key=True, index=True) - created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column=Column(TIMESTAMP(timezone=True), nullable=False), - ) - updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column=Column(TIMESTAMP(timezone=True), nullable=False, onupdate=lambda: datetime.now(timezone.utc)), - ) - - -class GraphAgentCreate(SQLModel): - name: str - description: str | None = None - state_schema: dict[str, Any] - parent_agent_id: UUID | None = None - is_published: bool = False - is_official: bool = False - - -class GraphAgentRead(GraphAgentBase): - id: UUID - created_at: datetime - updated_at: datetime - - -class GraphAgentUpdate(SQLModel): - name: str | None = None - description: str | None = None - state_schema: dict[str, Any] | None = None - is_active: bool | None = None - is_published: bool | None = None - is_official: bool | None = None - - -class GraphNodeBase(SQLModel): - name: str = Field(max_length=100) - node_type: str = Field(max_length=50) # 'llm', 'tool', 'router', 'subagent' - config: dict[str, Any] = Field(sa_column=Column(JSON)) - graph_agent_id: UUID = Field(index=True) - position_x: float | None = None - position_y: float | None = None - - -class GraphNode(GraphNodeBase, table=True): - id: UUID = Field(default_factory=uuid4, primary_key=True, index=True) - created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column=Column(TIMESTAMP(timezone=True), nullable=False), - ) - updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column=Column(TIMESTAMP(timezone=True), nullable=False, onupdate=lambda: datetime.now(timezone.utc)), - ) - - -class GraphNodeCreate(SQLModel): - name: str - node_type: str - config: dict[str, Any] - graph_agent_id: UUID - position_x: float | None = None - position_y: float | None = None - - -class GraphNodeRead(GraphNodeBase): - id: UUID - created_at: datetime - updated_at: datetime - - -class GraphNodeUpdate(SQLModel): - name: str | None = None - node_type: str | None = None - config: dict[str, Any] | None = None - position_x: float | None = None - position_y: float | None = None - - -class 
GraphEdgeBase(SQLModel): - from_node_id: UUID = Field(index=True) - to_node_id: UUID = Field(index=True) - condition: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) # Conditional routing logic - graph_agent_id: UUID = Field(index=True) - label: str | None = Field(default=None, max_length=100) # For UI display - - -class GraphEdge(GraphEdgeBase, table=True): - id: UUID = Field(default_factory=uuid4, primary_key=True, index=True) - created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column=Column(TIMESTAMP(timezone=True), nullable=False), - ) - updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column=Column(TIMESTAMP(timezone=True), nullable=False, onupdate=lambda: datetime.now(timezone.utc)), - ) - - -class GraphEdgeCreate(SQLModel): - from_node_id: UUID - to_node_id: UUID - condition: dict[str, Any] | None = None - graph_agent_id: UUID - label: str | None = None - - -class GraphEdgeRead(GraphEdgeBase): - id: UUID - created_at: datetime - updated_at: datetime - - -class GraphEdgeUpdate(SQLModel): - condition: dict[str, Any] | None = None - label: str | None = None - - -# Composite models for complex operations -class GraphAgentWithGraph(GraphAgentRead): - nodes: list[GraphNodeRead] = [] - edges: list[GraphEdgeRead] = [] - - -class GraphAgentCreateWithGraph(SQLModel): - agent: GraphAgentCreate - nodes: list[GraphNodeCreate] = [] - edges: list[GraphEdgeCreate] = [] - - -class GraphExecutionResult(SQLModel): - """Result from executing a graph agent""" - - agent_id: UUID - final_state: dict[str, Any] - execution_steps: list[dict[str, Any]] # Step-by-step execution log - success: bool - error_message: str | None = None - execution_time_ms: int diff --git a/service/app/models/message.py b/service/app/models/message.py index 0a6be71d..03ebd1fb 100644 --- a/service/app/models/message.py +++ b/service/app/models/message.py @@ -1,9 +1,9 @@ from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 from sqlalchemy import TIMESTAMP -from sqlmodel import Column, Field, SQLModel +from sqlmodel import JSON, Column, Field, SQLModel if TYPE_CHECKING: from .citation import CitationRead @@ -18,6 +18,8 @@ class MessageBase(SQLModel): topic_id: UUID = Field(index=True) # Thinking/reasoning content from models like Claude, DeepSeek R1, Gemini 3 thinking_content: str | None = None + # Agent metadata for storing additional context (agent state, etc.) 
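+    # A hypothetical shape (keys are illustrative, not a fixed schema), e.g.
+    #   {"execution_id": "exec-001", "agent_type": "graph",
+    #    "final_state_summary": {"updated_keys": ["report_outline"]}}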
+ agent_metadata: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) class Message(MessageBase, table=True): @@ -66,3 +68,4 @@ class MessageUpdate(SQLModel): role: str | None = None content: str | None = None thinking_content: str | None = None + agent_metadata: dict[str, Any] | None = None diff --git a/service/app/repos/graph.py b/service/app/repos/graph.py deleted file mode 100644 index 835c1734..00000000 --- a/service/app/repos/graph.py +++ /dev/null @@ -1,555 +0,0 @@ -import logging -from typing import Any -from uuid import UUID - -from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession - -from app.models.graph import ( - GraphAgent, - GraphAgentCreate, - GraphAgentCreateWithGraph, - GraphAgentUpdate, - GraphAgentWithGraph, - GraphEdge, - GraphEdgeCreate, - GraphEdgeRead, - GraphEdgeUpdate, - GraphExecutionResult, - GraphNode, - GraphNodeCreate, - GraphNodeRead, - GraphNodeUpdate, -) - -logger = logging.getLogger(__name__) - - -class GraphRepository: - def __init__(self, db: AsyncSession) -> None: - self.db = db - - # GraphAgent operations - async def get_graph_agent_by_id(self, agent_id: UUID) -> GraphAgent | None: - """ - Fetches a graph agent by its ID. - - Args: - agent_id: The UUID of the graph agent to fetch. - - Returns: - The GraphAgent, or None if not found. - """ - logger.debug(f"Fetching graph agent with id: {agent_id}") - return await self.db.get(GraphAgent, agent_id) - - async def get_graph_agents_by_user(self, user_id: str) -> list[GraphAgent]: - """ - Fetches all graph agents for a given user. - - Args: - user_id: The user ID. - - Returns: - List of GraphAgent instances. - """ - logger.debug(f"Fetching graph agents for user_id: {user_id}") - statement = select(GraphAgent).where(GraphAgent.user_id == user_id) - result = await self.db.exec(statement) - return list(result.all()) - - async def get_all_published_graph_agents(self) -> list[GraphAgent]: - """ - Fetches all published graph agents from all users. - - This is used for the Explorer view where users can discover - published graphs from other users. - - Returns: - List of published GraphAgent instances, ordered by creation time (newest first). - """ - logger.debug("Fetching all published graph agents") - statement = ( - select(GraphAgent).where(col(GraphAgent.is_published).is_(True)).order_by(col(GraphAgent.created_at).desc()) - ) - result = await self.db.exec(statement) - return list(result.all()) - - async def get_published_graph_agents_by_user(self, user_id: str) -> list[GraphAgent]: - """ - Fetches all published graph agents for a given user. - - Args: - user_id: The user ID. - - Returns: - List of GraphAgent instances. - """ - logger.debug(f"Fetching published graph agents for user_id: {user_id}") - statement = ( - select(GraphAgent).where(GraphAgent.user_id == user_id).where(col(GraphAgent.is_published).is_(True)) - ) - result = await self.db.exec(statement) - return list(result.all()) - - async def get_official_graph_agents(self) -> list[GraphAgent]: - """ - Fetches all official graph agents (is_official=True). - - Official agents are accessible to all users and typically represent - built-in or verified agents from the platform. - - Returns: - List of official GraphAgent instances, ordered by creation time (newest first). 
- """ - logger.debug("Fetching all official graph agents") - statement = ( - select(GraphAgent).where(col(GraphAgent.is_official).is_(True)).order_by(col(GraphAgent.created_at).desc()) - ) - result = await self.db.exec(statement) - return list(result.all()) - - async def get_published_official_agents(self) -> list[GraphAgent]: - """ - Fetches all official graph agents that are also published. - - Returns official agents where both is_official=True and is_published=True. - Useful for showing curated official agents in public-facing interfaces. - - Returns: - List of official and published GraphAgent instances, ordered by creation time (newest first). - """ - logger.debug("Fetching published official graph agents") - statement = ( - select(GraphAgent) - .where(col(GraphAgent.is_official).is_(True)) - .where(col(GraphAgent.is_published).is_(True)) - .order_by(col(GraphAgent.created_at).desc()) - ) - result = await self.db.exec(statement) - return list(result.all()) - - async def get_graph_agent_with_graph(self, agent_id: UUID) -> GraphAgentWithGraph | None: - """ - Fetches a graph agent by its ID with all nodes and edges loaded. - - Args: - agent_id: The UUID of the graph agent to fetch. - - Returns: - The GraphAgentWithGraph with nodes and edges populated, or None if not found. - """ - logger.debug(f"Fetching graph agent with full graph for agent_id: {agent_id}") - - # Get the agent - agent = await self.db.get(GraphAgent, agent_id) - if not agent: - return None - - # Get all nodes for this agent - nodes_statement = select(GraphNode).where(GraphNode.graph_agent_id == agent_id) - nodes_result = await self.db.exec(nodes_statement) - nodes = list(nodes_result.all()) - - # Get all edges for this agent - edges_statement = select(GraphEdge).where(GraphEdge.graph_agent_id == agent_id) - edges_result = await self.db.exec(edges_statement) - edges = list(edges_result.all()) - - # Convert to GraphAgentWithGraph - agent_dict = agent.model_dump() - # Convert nodes and edges to Read models - node_reads = [GraphNodeRead(**node.model_dump()) for node in nodes] - edge_reads = [GraphEdgeRead(**edge.model_dump()) for edge in edges] - agent_with_graph = GraphAgentWithGraph(**agent_dict, nodes=node_reads, edges=edge_reads) - - return agent_with_graph - - async def create_graph_agent(self, agent_data: GraphAgentCreate, user_id: str) -> GraphAgent: - """ - Creates a new graph agent. - This function does NOT commit the transaction, but it does flush the session. - - Args: - agent_data: The Pydantic model containing the data for the new graph agent. - user_id: The user ID (from authentication). - - Returns: - The newly created GraphAgent instance. - """ - logger.debug(f"Creating new graph agent for user_id: {user_id}") - - # Create agent - agent_dict = agent_data.model_dump() - agent_dict["user_id"] = user_id - agent = GraphAgent(**agent_dict) - - self.db.add(agent) - await self.db.flush() - await self.db.refresh(agent) - - return agent - - async def create_graph_agent_with_graph( - self, agent_data: GraphAgentCreateWithGraph, user_id: str - ) -> GraphAgentWithGraph: - """ - Creates a new graph agent with nodes and edges in a single transaction. - This function does NOT commit the transaction, but it does flush the session. - - Args: - agent_data: The composite model containing agent, nodes, and edges data. - user_id: The user ID (from authentication). - - Returns: - The newly created GraphAgentWithGraph instance. 
- """ - logger.debug(f"Creating new graph agent with full graph for user_id: {user_id}") - - # Create the agent first - agent = await self.create_graph_agent(agent_data.agent, user_id) - - # Create nodes - created_nodes: list[GraphNode] = [] - node_id_mapping: dict[str, UUID] = {} # Map node names to IDs for edge creation - - for node_data in agent_data.nodes: - node_dict = node_data.model_dump() - node_dict["graph_agent_id"] = agent.id - node = GraphNode(**node_dict) - self.db.add(node) - created_nodes.append(node) - # Store mapping for edge creation (assuming node names are unique within a graph) - node_id_mapping[node.name] = node.id - - await self.db.flush() - - # Refresh nodes to get IDs - for node in created_nodes: - await self.db.refresh(node) - # Update mapping with actual UUIDs - node_id_mapping[node.name] = node.id - - # Create edges - created_edges: list[GraphEdge] = [] - for edge_data in agent_data.edges: - edge_dict = edge_data.model_dump() - edge_dict["graph_agent_id"] = agent.id - edge = GraphEdge(**edge_dict) - self.db.add(edge) - created_edges.append(edge) - - await self.db.flush() - - # Refresh edges - for edge in created_edges: - await self.db.refresh(edge) - - # Return composite model - agent_dict = agent.model_dump() - # Convert nodes and edges to Read models - node_reads = [GraphNodeRead(**node.model_dump()) for node in created_nodes] - edge_reads = [GraphEdgeRead(**edge.model_dump()) for edge in created_edges] - return GraphAgentWithGraph(**agent_dict, nodes=node_reads, edges=edge_reads) - - async def update_graph_agent(self, agent_id: UUID, agent_data: GraphAgentUpdate) -> GraphAgent | None: - """ - Updates an existing graph agent. - This function does NOT commit the transaction. - - Args: - agent_id: The UUID of the graph agent to update. - agent_data: The Pydantic model containing the update data. - - Returns: - The updated GraphAgent instance, or None if not found. - """ - logger.debug(f"Updating graph agent with id: {agent_id}") - agent = await self.db.get(GraphAgent, agent_id) - if not agent: - return None - - # Use safe update pattern - update_data = agent_data.model_dump(exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - if hasattr(agent, key): - setattr(agent, key, value) - - self.db.add(agent) - await self.db.flush() - await self.db.refresh(agent) - return agent - - async def delete_graph_agent(self, agent_id: UUID) -> bool: - """ - Deletes a graph agent and all its nodes and edges. - This function does NOT commit the transaction. - - Args: - agent_id: The UUID of the graph agent to delete. - - Returns: - True if the agent was deleted, False if not found. - """ - logger.debug(f"Deleting graph agent with id: {agent_id}") - - # Delete all edges first - edges_statement = select(GraphEdge).where(GraphEdge.graph_agent_id == agent_id) - edges_result = await self.db.exec(edges_statement) - edges = list(edges_result.all()) - for edge in edges: - await self.db.delete(edge) - - # Delete all nodes - nodes_statement = select(GraphNode).where(GraphNode.graph_agent_id == agent_id) - nodes_result = await self.db.exec(nodes_statement) - nodes = list(nodes_result.all()) - for node in nodes: - await self.db.delete(node) - - # Delete the agent - agent = await self.db.get(GraphAgent, agent_id) - if not agent: - return False - - await self.db.delete(agent) - await self.db.flush() - return True - - # GraphNode operations - async def get_nodes_by_agent(self, agent_id: UUID) -> list[GraphNode]: - """ - Fetches all nodes for a given graph agent. 
- - Args: - agent_id: The UUID of the graph agent. - - Returns: - List of GraphNode instances. - """ - logger.debug(f"Fetching nodes for graph agent: {agent_id}") - statement = select(GraphNode).where(GraphNode.graph_agent_id == agent_id) - result = await self.db.exec(statement) - return list(result.all()) - - async def create_node(self, node_data: GraphNodeCreate) -> GraphNode: - """ - Creates a new node. - This function does NOT commit the transaction. - - Args: - node_data: The Pydantic model containing the node data. - - Returns: - The newly created GraphNode instance. - """ - logger.debug(f"Creating new node for graph agent: {node_data.graph_agent_id}") - node_dict = node_data.model_dump() - node = GraphNode(**node_dict) - self.db.add(node) - await self.db.flush() - await self.db.refresh(node) - return node - - async def update_node(self, node_id: UUID, node_data: GraphNodeUpdate) -> GraphNode | None: - """ - Updates an existing node. - This function does NOT commit the transaction. - - Args: - node_id: The UUID of the node to update. - node_data: The Pydantic model containing the update data. - - Returns: - The updated GraphNode instance, or None if not found. - """ - logger.debug(f"Updating node with id: {node_id}") - node = await self.db.get(GraphNode, node_id) - if not node: - return None - - update_data = node_data.model_dump(exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - if hasattr(node, key): - setattr(node, key, value) - - self.db.add(node) - await self.db.flush() - await self.db.refresh(node) - return node - - async def delete_node(self, node_id: UUID) -> bool: - """ - Deletes a node and all edges connected to it. - This function does NOT commit the transaction. - - Args: - node_id: The UUID of the node to delete. - - Returns: - True if the node was deleted, False if not found. - """ - logger.debug(f"Deleting node with id: {node_id}") - - # Delete all edges connected to this node - edges_statement = select(GraphEdge).where( - (GraphEdge.from_node_id == node_id) | (GraphEdge.to_node_id == node_id) - ) - edges_result = await self.db.exec(edges_statement) - edges = list(edges_result.all()) - for edge in edges: - await self.db.delete(edge) - - # Delete the node - node = await self.db.get(GraphNode, node_id) - if not node: - return False - - await self.db.delete(node) - await self.db.flush() - return True - - # GraphEdge operations - async def get_edges_by_agent(self, agent_id: UUID) -> list[GraphEdge]: - """ - Fetches all edges for a given graph agent. - - Args: - agent_id: The UUID of the graph agent. - - Returns: - List of GraphEdge instances. - """ - logger.debug(f"Fetching edges for graph agent: {agent_id}") - statement = select(GraphEdge).where(GraphEdge.graph_agent_id == agent_id) - result = await self.db.exec(statement) - return list(result.all()) - - async def create_edge(self, edge_data: GraphEdgeCreate) -> GraphEdge: - """ - Creates a new edge. - This function does NOT commit the transaction. - - Args: - edge_data: The Pydantic model containing the edge data. - - Returns: - The newly created GraphEdge instance. - """ - logger.debug(f"Creating new edge for graph agent: {edge_data.graph_agent_id}") - edge_dict = edge_data.model_dump() - edge = GraphEdge(**edge_dict) - self.db.add(edge) - await self.db.flush() - await self.db.refresh(edge) - return edge - - async def update_edge(self, edge_id: UUID, edge_data: GraphEdgeUpdate) -> GraphEdge | None: - """ - Updates an existing edge. - This function does NOT commit the transaction. 
- - Args: - edge_id: The UUID of the edge to update. - edge_data: The Pydantic model containing the update data. - - Returns: - The updated GraphEdge instance, or None if not found. - """ - logger.debug(f"Updating edge with id: {edge_id}") - edge = await self.db.get(GraphEdge, edge_id) - if not edge: - return None - - update_data = edge_data.model_dump(exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - if hasattr(edge, key): - setattr(edge, key, value) - - self.db.add(edge) - await self.db.flush() - await self.db.refresh(edge) - return edge - - async def delete_edge(self, edge_id: UUID) -> bool: - """ - Deletes an edge. - This function does NOT commit the transaction. - - Args: - edge_id: The UUID of the edge to delete. - - Returns: - True if the edge was deleted, False if not found. - """ - logger.debug(f"Deleting edge with id: {edge_id}") - edge = await self.db.get(GraphEdge, edge_id) - if not edge: - return False - - await self.db.delete(edge) - await self.db.flush() - return True - - # Graph validation methods - async def validate_graph_structure(self, agent_id: UUID) -> dict[str, Any]: - """ - Validates the graph structure for connectivity and cycles. - - Args: - agent_id: The UUID of the graph agent to validate. - - Returns: - Dict containing validation results. - """ - logger.debug(f"Validating graph structure for agent: {agent_id}") - - nodes = await self.get_nodes_by_agent(agent_id) - edges = await self.get_edges_by_agent(agent_id) - - if not nodes: - return {"valid": False, "errors": ["Graph has no nodes"]} - - node_ids = {node.id for node in nodes} - validation_result = { - "valid": True, - "errors": [], - "warnings": [], - "node_count": len(nodes), - "edge_count": len(edges), - } - - # Check for invalid edge references - for edge in edges: - if edge.from_node_id not in node_ids: - validation_result["errors"].append(f"Edge {edge.id} references invalid from_node_id") - validation_result["valid"] = False - if edge.to_node_id not in node_ids: - validation_result["errors"].append(f"Edge {edge.id} references invalid to_node_id") - validation_result["valid"] = False - - # Check for isolated nodes (nodes with no connections) - connected_nodes = set() - for edge in edges: - connected_nodes.add(edge.from_node_id) - connected_nodes.add(edge.to_node_id) - - isolated_nodes = node_ids - connected_nodes - if isolated_nodes: - validation_result["warnings"].append(f"Found {len(isolated_nodes)} isolated nodes") - - return validation_result - - # Execution history (for future implementation) - async def save_execution_result(self, result: GraphExecutionResult) -> None: - """ - Saves execution result for audit/debugging purposes. - Future implementation could store this in a separate execution_history table. - - Args: - result: The execution result to save. 
- """ - logger.debug(f"Execution result for agent {result.agent_id}: {result.success}") - # For now, just log the result - # Future: store in execution_history table - pass diff --git a/service/app/schemas/agent.py b/service/app/schemas/agent.py deleted file mode 100644 index 895c3887..00000000 --- a/service/app/schemas/agent.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import StrEnum - - -class AgentScope(StrEnum): - OFFICIAL = "official" - USER = "user" - - -class AgentVisibility(StrEnum): - PUBLIC = "public" - PRIVATE = "private" diff --git a/service/app/schemas/agent_events.py b/service/app/schemas/agent_event_payloads.py similarity index 100% rename from service/app/schemas/agent_events.py rename to service/app/schemas/agent_event_payloads.py diff --git a/service/app/schemas/chat_event_payloads.py b/service/app/schemas/chat_event_payloads.py new file mode 100644 index 00000000..757699e6 --- /dev/null +++ b/service/app/schemas/chat_event_payloads.py @@ -0,0 +1,554 @@ +""" +Typed data structures for chat streaming events. + +This module provides TypedDict classes that define the exact shape of data +payloads for each chat event type. Use these instead of dict[str, Any] for +better type safety and IDE autocompletion. + +Example: + from app.schemas.chat_event_payloads import StreamingChunkEvent, StreamingChunkData + + event: StreamingChunkEvent = { + "type": ChatEventType.STREAMING_CHUNK, + "data": {"id": "stream_123", "content": "Hello"} + } +""" + +from typing import Any, TypedDict + +from typing_extensions import Literal, NotRequired + +from app.schemas.agent_event_payloads import ( + AgentEndData, + AgentErrorData, + AgentExecutionContext, + AgentStartData, + HumanInputReceivedData, + HumanInputRequiredData, + IterationEndData, + IterationStartData, + NodeEndData, + NodeStartData, + PhaseEndData, + PhaseStartData, + ProgressUpdateData, + StateUpdateData, + SubagentEndData, + SubagentStartData, +) +from app.schemas.chat_event_types import ChatEventType + +# ============================================================================= +# Data Payloads (the "data" field of each event) +# ============================================================================= + + +class StreamingStartData(TypedDict): + """Data payload for STREAMING_START event.""" + + id: str + + +class StreamingChunkData(TypedDict): + """Data payload for STREAMING_CHUNK event.""" + + id: str + content: str + + +class StreamingEndData(TypedDict): + """Data payload for STREAMING_END event.""" + + id: str + created_at: float + content: NotRequired[str] # Optional content for final streaming result + agent_state: NotRequired[dict[str, Any]] # Agent state metadata for persistence + + +class ProcessingData(TypedDict): + """Data payload for PROCESSING event.""" + + status: str + + +class LoadingData(TypedDict): + """Data payload for LOADING event.""" + + message: str + + +class ErrorData(TypedDict): + """Data payload for ERROR event.""" + + error: str + + +class ToolCallRequestData(TypedDict): + """Data payload for TOOL_CALL_REQUEST event.""" + + id: str + name: str + description: str + arguments: dict[str, Any] + status: str + timestamp: float + + +class ToolCallResponseData(TypedDict): + """Data payload for TOOL_CALL_RESPONSE event.""" + + toolCallId: str + status: str + result: str + error: NotRequired[str] + + +class TokenUsageData(TypedDict): + """Data payload for TOKEN_USAGE event.""" + + input_tokens: int + output_tokens: int + total_tokens: int + + +class CitationData(TypedDict): + """Single citation entry within 
SearchCitationsData.""" + + url: str + title: str | None + cited_text: str | None + start_index: int | None + end_index: int | None + search_queries: NotRequired[list[str]] + + +class SearchCitationsData(TypedDict): + """Data payload for SEARCH_CITATIONS event.""" + + citations: list[CitationData] + + +class GeneratedFileInfo(TypedDict): + """Single file entry within GeneratedFilesData.""" + + id: str + name: str + type: str + size: int + category: str + download_url: str + + +class GeneratedFilesData(TypedDict): + """Data payload for GENERATED_FILES event.""" + + files: list[GeneratedFileInfo] + + +class MessageSavedData(TypedDict): + """Data payload for MESSAGE_SAVED event.""" + + stream_id: str + db_id: str + created_at: str | None + + +class MessageData(TypedDict): + """Data payload for MESSAGE event (non-streaming response).""" + + id: str + content: str + + +class InsufficientBalanceData(TypedDict): + """Data payload for insufficient_balance event.""" + + error_code: str + message: str + message_cn: NotRequired[str] + details: NotRequired[dict[str, Any]] + action_required: str + + +class ThinkingStartData(TypedDict): + """Data payload for THINKING_START event.""" + + id: str + + +class ThinkingChunkData(TypedDict): + """Data payload for THINKING_CHUNK event.""" + + id: str + content: str + + +class ThinkingEndData(TypedDict): + """Data payload for THINKING_END event.""" + + id: str + + +# ============================================================================= +# Full Event Structures (type + data) +# ============================================================================= + + +class StreamingStartEvent(TypedDict): + """Full event structure for streaming start.""" + + type: Literal[ChatEventType.STREAMING_START] + data: StreamingStartData + + +class StreamingChunkEvent(TypedDict): + """Full event structure for streaming chunk.""" + + type: Literal[ChatEventType.STREAMING_CHUNK] + data: StreamingChunkData + + +class StreamingEndEvent(TypedDict): + """Full event structure for streaming end.""" + + type: Literal[ChatEventType.STREAMING_END] + data: StreamingEndData + + +class ProcessingEvent(TypedDict): + """Full event structure for processing status.""" + + type: Literal[ChatEventType.PROCESSING] + data: ProcessingData + + +class LoadingEvent(TypedDict): + """Full event structure for loading status.""" + + type: Literal[ChatEventType.LOADING] + data: LoadingData + + +class ErrorEvent(TypedDict): + """Full event structure for errors.""" + + type: Literal[ChatEventType.ERROR] + data: ErrorData + + +class ToolCallRequestEvent(TypedDict): + """Full event structure for tool call request.""" + + type: Literal[ChatEventType.TOOL_CALL_REQUEST] + data: ToolCallRequestData + + +class ToolCallResponseEvent(TypedDict): + """Full event structure for tool call response.""" + + type: Literal[ChatEventType.TOOL_CALL_RESPONSE] + data: ToolCallResponseData + + +class TokenUsageEvent(TypedDict): + """Full event structure for token usage.""" + + type: Literal[ChatEventType.TOKEN_USAGE] + data: TokenUsageData + + +class SearchCitationsEvent(TypedDict): + """Full event structure for search citations.""" + + type: Literal[ChatEventType.SEARCH_CITATIONS] + data: SearchCitationsData + + +class GeneratedFilesEvent(TypedDict): + """Full event structure for generated files.""" + + type: Literal[ChatEventType.GENERATED_FILES] + data: GeneratedFilesData + + +class MessageSavedEvent(TypedDict): + """Full event structure for message saved confirmation.""" + + type: Literal[ChatEventType.MESSAGE_SAVED] + 
data: MessageSavedData + + +class MessageEvent(TypedDict): + """Full event structure for non-streaming message.""" + + type: Literal[ChatEventType.MESSAGE] + data: MessageData + + +class InsufficientBalanceEvent(TypedDict): + """Full event structure for insufficient balance error.""" + + type: Literal[ChatEventType.INSUFFICIENT_BALANCE] + data: InsufficientBalanceData + + +class ThinkingStartEvent(TypedDict): + """Full event structure for thinking start.""" + + type: Literal[ChatEventType.THINKING_START] + data: ThinkingStartData + + +class ThinkingChunkEvent(TypedDict): + """Full event structure for thinking chunk.""" + + type: Literal[ChatEventType.THINKING_CHUNK] + data: ThinkingChunkData + + +class ThinkingEndEvent(TypedDict): + """Full event structure for thinking end.""" + + type: Literal[ChatEventType.THINKING_END] + data: ThinkingEndData + + +class AgentStartEvent(TypedDict): + """Full event structure for agent start.""" + + type: Literal[ChatEventType.AGENT_START] + data: AgentStartData + + +class AgentEndEvent(TypedDict): + """Full event structure for agent end.""" + + type: Literal[ChatEventType.AGENT_END] + data: AgentEndData + + +class AgentErrorEvent(TypedDict): + """Full event structure for agent error.""" + + type: Literal[ChatEventType.AGENT_ERROR] + data: AgentErrorData + + +class PhaseStartEvent(TypedDict): + """Full event structure for phase start.""" + + type: Literal[ChatEventType.PHASE_START] + data: PhaseStartData + + +class PhaseEndEvent(TypedDict): + """Full event structure for phase end.""" + + type: Literal[ChatEventType.PHASE_END] + data: PhaseEndData + + +class NodeStartEvent(TypedDict): + """Full event structure for node start.""" + + type: Literal[ChatEventType.NODE_START] + data: NodeStartData + + +class NodeEndEvent(TypedDict): + """Full event structure for node end.""" + + type: Literal[ChatEventType.NODE_END] + data: NodeEndData + + +class SubagentStartEvent(TypedDict): + """Full event structure for subagent start.""" + + type: Literal[ChatEventType.SUBAGENT_START] + data: SubagentStartData + + +class SubagentEndEvent(TypedDict): + """Full event structure for subagent end.""" + + type: Literal[ChatEventType.SUBAGENT_END] + data: SubagentEndData + + +class ProgressUpdateEvent(TypedDict): + """Full event structure for progress updates.""" + + type: Literal[ChatEventType.PROGRESS_UPDATE] + data: ProgressUpdateData + + +class IterationStartEvent(TypedDict): + """Full event structure for iteration start.""" + + type: Literal[ChatEventType.ITERATION_START] + data: IterationStartData + + +class IterationEndEvent(TypedDict): + """Full event structure for iteration end.""" + + type: Literal[ChatEventType.ITERATION_END] + data: IterationEndData + + +class StateUpdateEvent(TypedDict): + """Full event structure for state update.""" + + type: Literal[ChatEventType.STATE_UPDATE] + data: StateUpdateData + + +class HumanInputRequiredEvent(TypedDict): + """Full event structure for human input required.""" + + type: Literal[ChatEventType.HUMAN_INPUT_REQUIRED] + data: HumanInputRequiredData + + +class HumanInputReceivedEvent(TypedDict): + """Full event structure for human input received.""" + + type: Literal[ChatEventType.HUMAN_INPUT_RECEIVED] + data: HumanInputReceivedData + + +# ============================================================================= +# Union type for generic event handling +# ============================================================================= + +# Type alias for any streaming event +StreamingEvent = ( + StreamingStartEvent + | 
StreamingChunkEvent + | StreamingEndEvent + | ProcessingEvent + | LoadingEvent + | ErrorEvent + | ToolCallRequestEvent + | ToolCallResponseEvent + | TokenUsageEvent + | SearchCitationsEvent + | GeneratedFilesEvent + | MessageSavedEvent + | MessageEvent + | InsufficientBalanceEvent + | ThinkingStartEvent + | ThinkingChunkEvent + | ThinkingEndEvent + | AgentStartEvent + | AgentEndEvent + | AgentErrorEvent + | PhaseStartEvent + | PhaseEndEvent + | NodeStartEvent + | NodeEndEvent + | SubagentStartEvent + | SubagentEndEvent + | ProgressUpdateEvent + | IterationStartEvent + | IterationEndEvent + | StateUpdateEvent + | HumanInputRequiredEvent + | HumanInputReceivedEvent +) + + +AgentEvent = ( + AgentStartEvent + | AgentEndEvent + | AgentErrorEvent + | PhaseStartEvent + | PhaseEndEvent + | NodeStartEvent + | NodeEndEvent + | SubagentStartEvent + | SubagentEndEvent + | ProgressUpdateEvent + | IterationStartEvent + | IterationEndEvent + | StateUpdateEvent + | HumanInputRequiredEvent + | HumanInputReceivedEvent +) + + +__all__ = [ + # Data types + "StreamingStartData", + "StreamingChunkData", + "StreamingEndData", + "ProcessingData", + "LoadingData", + "ErrorData", + "ToolCallRequestData", + "ToolCallResponseData", + "TokenUsageData", + "CitationData", + "SearchCitationsData", + "GeneratedFileInfo", + "GeneratedFilesData", + "MessageSavedData", + "MessageData", + "InsufficientBalanceData", + "ThinkingStartData", + "ThinkingChunkData", + "ThinkingEndData", + # Event types + "StreamingStartEvent", + "StreamingChunkEvent", + "StreamingEndEvent", + "ProcessingEvent", + "LoadingEvent", + "ErrorEvent", + "ToolCallRequestEvent", + "ToolCallResponseEvent", + "TokenUsageEvent", + "SearchCitationsEvent", + "GeneratedFilesEvent", + "MessageSavedEvent", + "MessageEvent", + "InsufficientBalanceEvent", + "ThinkingStartEvent", + "ThinkingChunkEvent", + "ThinkingEndEvent", + "AgentStartData", + "AgentEndData", + "AgentErrorData", + "AgentExecutionContext", + "NodeStartData", + "NodeEndData", + "ProgressUpdateData", + "PhaseStartData", + "PhaseEndData", + "SubagentStartData", + "SubagentEndData", + "IterationStartData", + "IterationEndData", + "StateUpdateData", + "HumanInputRequiredData", + "HumanInputReceivedData", + "AgentStartEvent", + "AgentEndEvent", + "AgentErrorEvent", + "PhaseStartEvent", + "PhaseEndEvent", + "NodeStartEvent", + "NodeEndEvent", + "ProgressUpdateEvent", + "SubagentStartEvent", + "SubagentEndEvent", + "IterationStartEvent", + "IterationEndEvent", + "StateUpdateEvent", + "HumanInputRequiredEvent", + "HumanInputReceivedEvent", + # Union + "StreamingEvent", + "AgentEvent", +] diff --git a/service/app/schemas/chat_event_types.py b/service/app/schemas/chat_event_types.py index 2cf4443c..483e5f31 100644 --- a/service/app/schemas/chat_event_types.py +++ b/service/app/schemas/chat_event_types.py @@ -1,470 +1,156 @@ """ -Typed data structures for chat streaming events. +Centralized chat event constants. -This module provides TypedDict classes that define the exact shape of data -payloads for each chat event type. Use these instead of dict[str, Any] for -better type safety and IDE autocompletion. +Use StrEnum so enum values behave like strings in JSON payloads while +providing type safety and autocompletion across the codebase. 
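+
+Because StrEnum members are str subclasses, they serialize to their plain
+string values with no custom JSON encoder (a minimal illustration, assuming
+a dict-shaped event):
+
+    import json
+    json.dumps({"type": ChatEventType.STREAMING_CHUNK})
+    # -> '{"type": "streaming_chunk"}'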
Example: - from app.schemas.chat_event_types import StreamingChunkEvent, StreamingChunkData - - event: StreamingChunkEvent = { - "type": ChatEventType.STREAMING_CHUNK, - "data": {"id": "stream_123", "content": "Hello"} - } + from app.schemas.chat_event_types import ChatEventType + if event_type == ChatEventType.STREAMING_START: + ... """ -from typing import Any, TypedDict - -from typing_extensions import Literal, NotRequired - -from app.schemas.chat_events import ChatEventType - -# ============================================================================= -# Data Payloads (the "data" field of each event) -# ============================================================================= - - -class StreamingStartData(TypedDict): - """Data payload for STREAMING_START event.""" - - id: str - - -class StreamingChunkData(TypedDict): - """Data payload for STREAMING_CHUNK event.""" - - id: str - content: str - - -class StreamingEndData(TypedDict): - """Data payload for STREAMING_END event.""" - - id: str - created_at: float - content: NotRequired[str] # Optional content for final streaming result - - -class ProcessingData(TypedDict): - """Data payload for PROCESSING event.""" - - status: str - - -class LoadingData(TypedDict): - """Data payload for LOADING event.""" - - message: str - - -class ErrorData(TypedDict): - """Data payload for ERROR event.""" - - error: str - - -class ToolCallRequestData(TypedDict): - """Data payload for TOOL_CALL_REQUEST event.""" - - id: str - name: str - description: str - arguments: dict[str, Any] - status: str - timestamp: float - - -class ToolCallResponseData(TypedDict): - """Data payload for TOOL_CALL_RESPONSE event.""" - - toolCallId: str - status: str - result: str - error: NotRequired[str] - - -class TokenUsageData(TypedDict): - """Data payload for TOKEN_USAGE event.""" - - input_tokens: int - output_tokens: int - total_tokens: int - - -class CitationData(TypedDict): - """Single citation entry within SearchCitationsData.""" - - url: str - title: str | None - cited_text: str | None - start_index: int | None - end_index: int | None - search_queries: NotRequired[list[str]] - - -class SearchCitationsData(TypedDict): - """Data payload for SEARCH_CITATIONS event.""" - - citations: list[CitationData] - - -class GeneratedFileInfo(TypedDict): - """Single file entry within GeneratedFilesData.""" - - id: str - name: str - type: str - size: int - category: str - download_url: str - - -class GeneratedFilesData(TypedDict): - """Data payload for GENERATED_FILES event.""" - - files: list[GeneratedFileInfo] - - -class MessageSavedData(TypedDict): - """Data payload for MESSAGE_SAVED event.""" - - stream_id: str - db_id: str - created_at: str | None - - -class MessageData(TypedDict): - """Data payload for MESSAGE event (non-streaming response).""" - - id: str - content: str - - -class InsufficientBalanceData(TypedDict): - """Data payload for insufficient_balance event.""" - - error_code: str - message: str - message_cn: NotRequired[str] - details: NotRequired[dict[str, Any]] - action_required: str - - -class ThinkingStartData(TypedDict): - """Data payload for THINKING_START event.""" - - id: str - - -class ThinkingChunkData(TypedDict): - """Data payload for THINKING_CHUNK event.""" - - id: str - content: str - - -class ThinkingEndData(TypedDict): - """Data payload for THINKING_END event.""" +from enum import StrEnum +from typing import FrozenSet - id: str +class ChatEventType(StrEnum): + """Server -> Client event types used across chat flows.""" -class 
AgentContextData(TypedDict): - """Common context shared by agent and node events.""" + # Generic + MESSAGE = "message" + LOADING = "loading" + ERROR = "error" + PROCESSING = "processing" - agent_id: str - agent_name: str - agent_type: str - execution_id: str - depth: int - execution_path: list[str] - started_at: int - current_node: NotRequired[str] + # Streaming lifecycle + STREAMING_START = "streaming_start" + STREAMING_CHUNK = "streaming_chunk" + STREAMING_END = "streaming_end" + # Tool invocation + TOOL_CALL_REQUEST = "tool_call_request" + TOOL_CALL_RESPONSE = "tool_call_response" -class AgentStartData(TypedDict): - """Data payload for AGENT_START event.""" + # Post-processing/ack + MESSAGE_SAVED = "message_saved" - context: AgentContextData + # Token usage tracking + TOKEN_USAGE = "token_usage" + # Built-in search citations + SEARCH_CITATIONS = "search_citations" -class AgentEndData(TypedDict): - """Data payload for AGENT_END event.""" + # Generated content + GENERATED_FILES = "generated_files" - context: AgentContextData - status: str - duration_ms: int + # Balance/billing events + INSUFFICIENT_BALANCE = "insufficient_balance" + # Thinking/reasoning content (for models like Claude, DeepSeek R1, OpenAI o1) + THINKING_START = "thinking_start" + THINKING_CHUNK = "thinking_chunk" + THINKING_END = "thinking_end" -class NodeStartData(TypedDict): - """Data payload for NODE_START event.""" + # === Agent Execution Events (for complex graph-based agents) === - node_id: str - node_name: str - node_type: str - context: AgentContextData + # Agent lifecycle + AGENT_START = "agent_start" + AGENT_END = "agent_end" + AGENT_ERROR = "agent_error" + # Phase/workflow stage execution + PHASE_START = "phase_start" + PHASE_END = "phase_end" -class NodeEndData(TypedDict): - """Data payload for NODE_END event.""" + # Individual node execution + NODE_START = "node_start" + NODE_END = "node_end" - node_id: str - node_name: str - node_type: str - status: str - duration_ms: int - context: AgentContextData + # Subagent execution (nested agents) + SUBAGENT_START = "subagent_start" + SUBAGENT_END = "subagent_end" + # Progress updates + PROGRESS_UPDATE = "progress_update" -class ProgressUpdateData(TypedDict): - """Data payload for PROGRESS_UPDATE event.""" + # State changes (non-sensitive updates) + STATE_UPDATE = "state_update" - progress_percent: int - message: str - context: AgentContextData + # Iteration events (for loops in graph execution) + ITERATION_START = "iteration_start" + ITERATION_END = "iteration_end" + # Human-in-the-loop events + HUMAN_INPUT_REQUIRED = "human_input_required" + HUMAN_INPUT_RECEIVED = "human_input_received" -# ============================================================================= -# Full Event Structures (type + data) -# ============================================================================= +class ChatClientEventType(StrEnum): + """Client -> Server event types (messages coming from the frontend).""" -class StreamingStartEvent(TypedDict): - """Full event structure for streaming start.""" + # Regular chat message (default when no explicit type provided) + MESSAGE = "message" - type: Literal[ChatEventType.STREAMING_START] - data: StreamingStartData + # Tool confirmation workflow + TOOL_CALL_CONFIRM = "tool_call_confirm" + TOOL_CALL_CANCEL = "tool_call_cancel" -class StreamingChunkEvent(TypedDict): - """Full event structure for streaming chunk.""" +class ToolCallStatus(StrEnum): + """Status values for tool call lifecycle.""" - type: Literal[ChatEventType.STREAMING_CHUNK] - 
data: StreamingChunkData + EXECUTING = "executing" + COMPLETED = "completed" + FAILED = "failed" -class StreamingEndEvent(TypedDict): - """Full event structure for streaming end.""" +class ProcessingStatus(StrEnum): + """Status values used with the PROCESSING event.""" - type: Literal[ChatEventType.STREAMING_END] - data: StreamingEndData + PREPARING_REQUEST = "preparing_request" + PREPARING_GRAPH_EXECUTION = "preparing_graph_execution" + EXECUTING_GRAPH = "executing_graph" + PROCESSING_GRAPH_RESULT = "processing_graph_result" -class ProcessingEvent(TypedDict): - """Full event structure for processing status.""" - - type: Literal[ChatEventType.PROCESSING] - data: ProcessingData - - -class LoadingEvent(TypedDict): - """Full event structure for loading status.""" - - type: Literal[ChatEventType.LOADING] - data: LoadingData - - -class ErrorEvent(TypedDict): - """Full event structure for errors.""" - - type: Literal[ChatEventType.ERROR] - data: ErrorData - - -class ToolCallRequestEvent(TypedDict): - """Full event structure for tool call request.""" - - type: Literal[ChatEventType.TOOL_CALL_REQUEST] - data: ToolCallRequestData - - -class ToolCallResponseEvent(TypedDict): - """Full event structure for tool call response.""" - - type: Literal[ChatEventType.TOOL_CALL_RESPONSE] - data: ToolCallResponseData - - -class TokenUsageEvent(TypedDict): - """Full event structure for token usage.""" - - type: Literal[ChatEventType.TOKEN_USAGE] - data: TokenUsageData - - -class SearchCitationsEvent(TypedDict): - """Full event structure for search citations.""" - - type: Literal[ChatEventType.SEARCH_CITATIONS] - data: SearchCitationsData - - -class GeneratedFilesEvent(TypedDict): - """Full event structure for generated files.""" - - type: Literal[ChatEventType.GENERATED_FILES] - data: GeneratedFilesData - - -class MessageSavedEvent(TypedDict): - """Full event structure for message saved confirmation.""" - - type: Literal[ChatEventType.MESSAGE_SAVED] - data: MessageSavedData - - -class MessageEvent(TypedDict): - """Full event structure for non-streaming message.""" - - type: Literal[ChatEventType.MESSAGE] - data: MessageData - - -class InsufficientBalanceEvent(TypedDict): - """Full event structure for insufficient balance error.""" - - type: Literal[ChatEventType.INSUFFICIENT_BALANCE] - data: InsufficientBalanceData - - -class ThinkingStartEvent(TypedDict): - """Full event structure for thinking start.""" - - type: Literal[ChatEventType.THINKING_START] - data: ThinkingStartData - - -class ThinkingChunkEvent(TypedDict): - """Full event structure for thinking chunk.""" - - type: Literal[ChatEventType.THINKING_CHUNK] - data: ThinkingChunkData - - -class ThinkingEndEvent(TypedDict): - """Full event structure for thinking end.""" - - type: Literal[ChatEventType.THINKING_END] - data: ThinkingEndData - - -class AgentStartEvent(TypedDict): - """Full event structure for agent start.""" - - type: Literal[ChatEventType.AGENT_START] - data: AgentStartData - - -class AgentEndEvent(TypedDict): - """Full event structure for agent end.""" - - type: Literal[ChatEventType.AGENT_END] - data: AgentEndData - - -class NodeStartEvent(TypedDict): - """Full event structure for node start.""" - - type: Literal[ChatEventType.NODE_START] - data: NodeStartData - - -class NodeEndEvent(TypedDict): - """Full event structure for node end.""" - - type: Literal[ChatEventType.NODE_END] - data: NodeEndData - - -class ProgressUpdateEvent(TypedDict): - """Full event structure for progress updates.""" - - type: Literal[ChatEventType.PROGRESS_UPDATE] - 
data: ProgressUpdateData - - -# ============================================================================= -# Union type for generic event handling -# ============================================================================= +# Helpful groupings for conditional logic +SERVER_STREAMING_EVENTS: FrozenSet[ChatEventType] = frozenset( + { + ChatEventType.STREAMING_START, + ChatEventType.STREAMING_CHUNK, + ChatEventType.STREAMING_END, + } +) -# Type alias for any streaming event -StreamingEvent = ( - StreamingStartEvent - | StreamingChunkEvent - | StreamingEndEvent - | ProcessingEvent - | LoadingEvent - | ErrorEvent - | ToolCallRequestEvent - | ToolCallResponseEvent - | TokenUsageEvent - | SearchCitationsEvent - | GeneratedFilesEvent - | MessageSavedEvent - | MessageEvent - | InsufficientBalanceEvent - | ThinkingStartEvent - | ThinkingChunkEvent - | ThinkingEndEvent - | AgentStartEvent - | AgentEndEvent - | NodeStartEvent - | NodeEndEvent - | ProgressUpdateEvent +SERVER_TOOL_EVENTS: FrozenSet[ChatEventType] = frozenset( + {ChatEventType.TOOL_CALL_REQUEST, ChatEventType.TOOL_CALL_RESPONSE} ) +SERVER_AGENT_EVENTS: FrozenSet[ChatEventType] = frozenset( + { + ChatEventType.AGENT_START, + ChatEventType.AGENT_END, + ChatEventType.AGENT_ERROR, + ChatEventType.PHASE_START, + ChatEventType.PHASE_END, + ChatEventType.NODE_START, + ChatEventType.NODE_END, + ChatEventType.SUBAGENT_START, + ChatEventType.SUBAGENT_END, + ChatEventType.PROGRESS_UPDATE, + ChatEventType.STATE_UPDATE, + ChatEventType.ITERATION_START, + ChatEventType.ITERATION_END, + } +) __all__ = [ - # Data types - "StreamingStartData", - "StreamingChunkData", - "StreamingEndData", - "ProcessingData", - "LoadingData", - "ErrorData", - "ToolCallRequestData", - "ToolCallResponseData", - "TokenUsageData", - "CitationData", - "SearchCitationsData", - "GeneratedFileInfo", - "GeneratedFilesData", - "MessageSavedData", - "MessageData", - "InsufficientBalanceData", - "ThinkingStartData", - "ThinkingChunkData", - "ThinkingEndData", - # Event types - "StreamingStartEvent", - "StreamingChunkEvent", - "StreamingEndEvent", - "ProcessingEvent", - "LoadingEvent", - "ErrorEvent", - "ToolCallRequestEvent", - "ToolCallResponseEvent", - "TokenUsageEvent", - "SearchCitationsEvent", - "GeneratedFilesEvent", - "MessageSavedEvent", - "MessageEvent", - "InsufficientBalanceEvent", - "ThinkingStartEvent", - "ThinkingChunkEvent", - "ThinkingEndEvent", - "AgentContextData", - "AgentStartData", - "AgentEndData", - "NodeStartData", - "NodeEndData", - "ProgressUpdateData", - "AgentStartEvent", - "AgentEndEvent", - "NodeStartEvent", - "NodeEndEvent", - "ProgressUpdateEvent", - # Union - "StreamingEvent", + "ChatEventType", + "ChatClientEventType", + "ToolCallStatus", + "ProcessingStatus", + "SERVER_STREAMING_EVENTS", + "SERVER_TOOL_EVENTS", + "SERVER_AGENT_EVENTS", ] diff --git a/service/app/schemas/chat_events.py b/service/app/schemas/chat_events.py deleted file mode 100644 index 5db68be0..00000000 --- a/service/app/schemas/chat_events.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Centralized chat event constants. - -Use StrEnum so enum values behave like strings in JSON payloads while -providing type safety and autocompletion across the codebase. - -Example: - from app.schemas.chat_events import ChatEventType - if event_type == ChatEventType.STREAMING_START: - ... 
-""" - -from enum import StrEnum -from typing import FrozenSet - - -class ChatEventType(StrEnum): - """Server -> Client event types used across chat flows.""" - - # Generic - MESSAGE = "message" - LOADING = "loading" - ERROR = "error" - PROCESSING = "processing" - - # Streaming lifecycle - STREAMING_START = "streaming_start" - STREAMING_CHUNK = "streaming_chunk" - STREAMING_END = "streaming_end" - - # Tool invocation - TOOL_CALL_REQUEST = "tool_call_request" - TOOL_CALL_RESPONSE = "tool_call_response" - - # Post-processing/ack - MESSAGE_SAVED = "message_saved" - - # Token usage tracking - TOKEN_USAGE = "token_usage" - - # Built-in search citations - SEARCH_CITATIONS = "search_citations" - - # Generated content - GENERATED_FILES = "generated_files" - - # Balance/billing events - INSUFFICIENT_BALANCE = "insufficient_balance" - - # Thinking/reasoning content (for models like Claude, DeepSeek R1, OpenAI o1) - THINKING_START = "thinking_start" - THINKING_CHUNK = "thinking_chunk" - THINKING_END = "thinking_end" - - # === Agent Execution Events (for complex graph-based agents) === - - # Agent lifecycle - AGENT_START = "agent_start" - AGENT_END = "agent_end" - AGENT_ERROR = "agent_error" - - # Phase/workflow stage execution - PHASE_START = "phase_start" - PHASE_END = "phase_end" - - # Individual node execution - NODE_START = "node_start" - NODE_END = "node_end" - - # Subagent execution (nested agents) - SUBAGENT_START = "subagent_start" - SUBAGENT_END = "subagent_end" - - # Progress updates - PROGRESS_UPDATE = "progress_update" - - # State changes (non-sensitive updates) - STATE_UPDATE = "state_update" - - # Iteration events (for loops in graph execution) - ITERATION_START = "iteration_start" - ITERATION_END = "iteration_end" - - # Human-in-the-loop events - HUMAN_INPUT_REQUIRED = "human_input_required" - HUMAN_INPUT_RECEIVED = "human_input_received" - - -class ChatClientEventType(StrEnum): - """Client -> Server event types (messages coming from the frontend).""" - - # Regular chat message (default when no explicit type provided) - MESSAGE = "message" - - # Tool confirmation workflow - TOOL_CALL_CONFIRM = "tool_call_confirm" - TOOL_CALL_CANCEL = "tool_call_cancel" - - -class ToolCallStatus(StrEnum): - """Status values for tool call lifecycle.""" - - EXECUTING = "executing" - COMPLETED = "completed" - FAILED = "failed" - - -class ProcessingStatus(StrEnum): - """Status values used with the PROCESSING event.""" - - PREPARING_REQUEST = "preparing_request" - PREPARING_GRAPH_EXECUTION = "preparing_graph_execution" - EXECUTING_GRAPH = "executing_graph" - PROCESSING_GRAPH_RESULT = "processing_graph_result" - - -# Helpful groupings for conditional logic -SERVER_STREAMING_EVENTS: FrozenSet[ChatEventType] = frozenset( - { - ChatEventType.STREAMING_START, - ChatEventType.STREAMING_CHUNK, - ChatEventType.STREAMING_END, - } -) - -SERVER_TOOL_EVENTS: FrozenSet[ChatEventType] = frozenset( - {ChatEventType.TOOL_CALL_REQUEST, ChatEventType.TOOL_CALL_RESPONSE} -) - -SERVER_AGENT_EVENTS: FrozenSet[ChatEventType] = frozenset( - { - ChatEventType.AGENT_START, - ChatEventType.AGENT_END, - ChatEventType.AGENT_ERROR, - ChatEventType.PHASE_START, - ChatEventType.PHASE_END, - ChatEventType.NODE_START, - ChatEventType.NODE_END, - ChatEventType.SUBAGENT_START, - ChatEventType.SUBAGENT_END, - ChatEventType.PROGRESS_UPDATE, - ChatEventType.STATE_UPDATE, - ChatEventType.ITERATION_START, - ChatEventType.ITERATION_END, - } -) - -__all__ = [ - "ChatEventType", - "ChatClientEventType", - "ToolCallStatus", - "ProcessingStatus", - 
"SERVER_STREAMING_EVENTS", - "SERVER_TOOL_EVENTS", - "SERVER_AGENT_EVENTS", -] diff --git a/service/migrations/versions/a05b7d38a4a0_add_graph_config.py b/service/migrations/versions/a05b7d38a4a0_add_graph_config.py new file mode 100644 index 00000000..7d9c0d6f --- /dev/null +++ b/service/migrations/versions/a05b7d38a4a0_add_graph_config.py @@ -0,0 +1,33 @@ +"""add_graph_config + +Revision ID: a05b7d38a4a0 +Revises: a7b3c8e51f92 +Create Date: 2026-01-07 22:29:03.211631 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "a05b7d38a4a0" +down_revision: Union[str, Sequence[str], None] = "a7b3c8e51f92" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("agent", sa.Column("graph_config", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("agent", "graph_config") + # ### end Alembic commands ### diff --git a/service/migrations/versions/a7b3c8e51f92_drop_legacy_graph_tables.py b/service/migrations/versions/a7b3c8e51f92_drop_legacy_graph_tables.py new file mode 100644 index 00000000..d78a289b --- /dev/null +++ b/service/migrations/versions/a7b3c8e51f92_drop_legacy_graph_tables.py @@ -0,0 +1,110 @@ +"""drop_legacy_graph_tables + +Revision ID: a7b3c8e51f92 +Revises: d25101ce4d9a +Create Date: 2026-01-07 + +Drop the legacy graphagent, graphnode, and graphedge tables. +These tables are replaced by the new graph_config JSON field in the agent table. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "a7b3c8e51f92" +down_revision: Union[str, Sequence[str], None] = "d25101ce4d9a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Drop legacy graph tables.""" + # Drop graphnode table and its indexes + op.drop_index(op.f("ix_graphnode_id"), table_name="graphnode") + op.drop_index(op.f("ix_graphnode_graph_agent_id"), table_name="graphnode") + op.drop_table("graphnode") + + # Drop graphedge table and its indexes + op.drop_index(op.f("ix_graphedge_to_node_id"), table_name="graphedge") + op.drop_index(op.f("ix_graphedge_id"), table_name="graphedge") + op.drop_index(op.f("ix_graphedge_graph_agent_id"), table_name="graphedge") + op.drop_index(op.f("ix_graphedge_from_node_id"), table_name="graphedge") + op.drop_table("graphedge") + + # Drop graphagent table and its indexes + op.drop_index(op.f("ix_graphagent_user_id"), table_name="graphagent") + op.drop_index(op.f("ix_graphagent_parent_agent_id"), table_name="graphagent") + op.drop_index(op.f("ix_graphagent_id"), table_name="graphagent") + # Drop is_published and is_official columns if they exist (added in later migrations) + bind = op.get_bind() + inspector = sa.inspect(bind) + columns = [col["name"] for col in inspector.get_columns("graphagent")] + if "is_published" in columns: + op.drop_index(op.f("ix_graphagent_is_published"), table_name="graphagent") + if "is_official" in columns: + op.drop_index(op.f("ix_graphagent_is_official"), table_name="graphagent") + op.drop_table("graphagent") + + +def downgrade() -> None: + """Recreate legacy graph tables (for rollback).""" + # Recreate graphagent table + op.create_table( + "graphagent", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.Column("state_schema", sa.JSON(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("parent_agent_id", sa.Uuid(), nullable=True), + sa.Column("user_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_published", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.Column("is_official", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_graphagent_id"), "graphagent", ["id"], unique=False) + op.create_index(op.f("ix_graphagent_parent_agent_id"), "graphagent", ["parent_agent_id"], unique=False) + op.create_index(op.f("ix_graphagent_user_id"), "graphagent", ["user_id"], unique=False) + op.create_index(op.f("ix_graphagent_is_published"), "graphagent", ["is_published"], unique=False) + op.create_index(op.f("ix_graphagent_is_official"), "graphagent", ["is_official"], unique=False) + + # Recreate graphedge table + op.create_table( + "graphedge", + sa.Column("from_node_id", sa.Uuid(), nullable=False), + sa.Column("to_node_id", sa.Uuid(), nullable=False), + sa.Column("condition", sa.JSON(), nullable=True), + sa.Column("graph_agent_id", sa.Uuid(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), 
nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_graphedge_from_node_id"), "graphedge", ["from_node_id"], unique=False) + op.create_index(op.f("ix_graphedge_graph_agent_id"), "graphedge", ["graph_agent_id"], unique=False) + op.create_index(op.f("ix_graphedge_id"), "graphedge", ["id"], unique=False) + op.create_index(op.f("ix_graphedge_to_node_id"), "graphedge", ["to_node_id"], unique=False) + + # Recreate graphnode table + op.create_table( + "graphnode", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False), + sa.Column("node_type", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column("config", sa.JSON(), nullable=True), + sa.Column("graph_agent_id", sa.Uuid(), nullable=False), + sa.Column("position_x", sa.Float(), nullable=True), + sa.Column("position_y", sa.Float(), nullable=True), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_graphnode_graph_agent_id"), "graphnode", ["graph_agent_id"], unique=False) + op.create_index(op.f("ix_graphnode_id"), "graphnode", ["id"], unique=False) diff --git a/service/migrations/versions/f0d7b93430e1_add_agent_metadata.py b/service/migrations/versions/f0d7b93430e1_add_agent_metadata.py new file mode 100644 index 00000000..ab9fe05c --- /dev/null +++ b/service/migrations/versions/f0d7b93430e1_add_agent_metadata.py @@ -0,0 +1,33 @@ +"""add_agent_metadata + +Revision ID: f0d7b93430e1 +Revises: a05b7d38a4a0 +Create Date: 2026-01-07 23:26:44.557183 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "f0d7b93430e1" +down_revision: Union[str, Sequence[str], None] = "a05b7d38a4a0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("message", sa.Column("agent_metadata", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
###
+    op.drop_column("message", "agent_metadata")
+    # ### end Alembic commands ###
diff --git a/service/pyproject.toml b/service/pyproject.toml
index 3195506e..2642cadc 100644
--- a/service/pyproject.toml
+++ b/service/pyproject.toml
@@ -40,12 +40,15 @@ dependencies = [
     "websockets>=13.0,<14.0",
     "reportlab>=4.4.7",
     "langchain-qwq>=0.3.1",
+    "jinja2>=3.1.6",
 ]
 
 [dependency-groups]
 dev = [
     "types-pyjwt>=1.7.1",
     "types-requests>=2.32.4.20250611",
+    "types-aioboto3>=15.5.0",
+    "types-openpyxl>=3.1.5.20250919",
     "pyright>=1.1.390",
     "pre-commit>=4.2.0",
     "ruff>=0.14.3",
@@ -100,10 +103,10 @@ venv = ".venv"
 # strictSetInference = true  # Strictly infer collection types
 # strictParameterNoneValue = true  # Strictly check that None parameter defaults match the declared type
 
-# ============================================================================
-# Report configuration - error level (error)
-# ============================================================================
-# These issues cause type checking to fail and must be fixed
+# reportMissingTypeStubs = false
+# reportUnknownMemberType = "warning"
+# reportUnknownArgumentType = "warning"
+# reportUnknownVariableType = "warning"
 
 reportMissingImports = "error"  # Missing import (module does not exist)
 reportDuplicateImport = "error"  # Duplicate import of the same module
@@ -140,7 +143,7 @@ reportUnnecessaryContains = "warning"  # Unnecessary membership (in) check
 reportUnnecessaryTypeIgnoreComment = "warning"  # Unnecessary type: ignore comment (forces review of every type: ignore)
 reportTypeCommentUsage = "warning"  # Type comment used instead of a type annotation
 reportUnknownParameterType = "warning"  # TODO: unknown parameter types; upgrade to error later and forbid strictly
-reportUnknownArgumentType = "warning"  # TODO: unknown parameter types; upgrade to error later and forbid strictly
+reportUnknownArgumentType = "warning"  # Kept at "warning": LangChain/LangGraph stubs have incomplete types
 
 # ============================================================================
 # Report configuration - disabled (none)
diff --git a/service/tests/integration/test_integration.py b/service/tests/integration/test_integration.py
index 13f2e963..cdbd7f7a 100644
--- a/service/tests/integration/test_integration.py
+++ b/service/tests/integration/test_integration.py
@@ -1,8 +1,8 @@
 """Integration tests for the Xyzen service."""
 
 import pytest
-from httpx import AsyncClient
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 from sqlmodel.ext.asyncio.session import AsyncSession
 
diff --git a/service/tests/unit/handler/mcp/test_knowledge_limits.py b/service/tests/unit/handler/mcp/test_knowledge_limits.py
index 35f410c7..29717cf0 100644
--- a/service/tests/unit/handler/mcp/test_knowledge_limits.py
+++ b/service/tests/unit/handler/mcp/test_knowledge_limits.py
@@ -104,7 +104,7 @@ async def test_read_file_text_mode(mock_deps: tuple[MagicMock, MagicMock], mocke
     mock_get_handler.return_value = mock_handler
 
     # Test call
-    result: dict[str, Any] = await read_file.fn("ks123", "doc.txt", mode="text")  # type: ignore
+    result: dict[str, Any] = await read_file.fn("ks123", "doc.txt", mode="text")
 
     # Verify
     assert isinstance(result, dict)
diff --git a/service/tests/unit/test_core/test_thinking_events.py b/service/tests/unit/test_core/test_thinking_events.py
index a4b5bb51..4a966aed 100644
--- a/service/tests/unit/test_core/test_thinking_events.py
+++ b/service/tests/unit/test_core/test_thinking_events.py
@@ -5,8 +5,10 @@
 from various provider formats (Anthropic, DeepSeek, etc.).
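+
+Representative chunk shapes (hedged examples; exact keys vary by provider
+and LangChain version):
+    Anthropic: content=[{"type": "thinking", "thinking": "..."}]
+    DeepSeek:  additional_kwargs={"reasoning_content": "..."}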
""" +from typing import Any + from app.core.chat.stream_handlers import ThinkingEventHandler -from app.schemas.chat_events import ChatEventType +from app.schemas.chat_event_types import ChatEventType class MockMessageChunk: @@ -14,9 +16,9 @@ class MockMessageChunk: def __init__( self, - content: str | list = "", - additional_kwargs: dict | None = None, - response_metadata: dict | None = None, + content: str | list[dict[str, Any]] = "", + additional_kwargs: dict[str, Any] | None = None, + response_metadata: dict[str, Any] | None = None, ): self.content = content self.additional_kwargs = additional_kwargs or {} diff --git a/service/uv.lock b/service/uv.lock index 535b9223..b655e2c3 100644 --- a/service/uv.lock +++ b/service/uv.lock @@ -271,6 +271,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/c5/f6ce561004db45f0b847c2cd9b19c67c6bf348a82018a48cb718be6b58b0/botocore-1.40.61-py3-none-any.whl", hash = "sha256:17ebae412692fd4824f99cde0f08d50126dc97954008e5ba2b522eb049238aa7", size = 14055973, upload-time = "2025-10-28T19:26:42.15Z" }, ] +[[package]] +name = "botocore-stubs" +version = "1.42.23" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "types-awscrt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/22/52e1ea0c727aa6b005a6d46bdfc0e7898e4b94374ef411bc211333a52bae/botocore_stubs-1.42.23.tar.gz", hash = "sha256:5388e98bed5d354e848772ef050afebab4c7aa64ef6b7aa9c03066c8fe9cacee", size = 42412, upload-time = "2026-01-06T21:27:43.134Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/36/879834d2d8be7097c25a233b0306ad967ed6f86affe192f042c2a7483f18/botocore_stubs-1.42.23-py3-none-any.whl", hash = "sha256:3687cf38a66a3e2b5771d1380592fabeba18b866729b438d1c90d26c123acd3c", size = 66761, upload-time = "2026-01-06T21:27:41.838Z" }, +] + [[package]] name = "bottleneck" version = "1.6.0" @@ -2730,6 +2742,7 @@ dependencies = [ { name = "google-genai" }, { name = "greenlet" }, { name = "httpx" }, + { name = "jinja2" }, { name = "langchain" }, { name = "langchain-anthropic" }, { name = "langchain-google-genai" }, @@ -2770,6 +2783,8 @@ dev = [ { name = "pytest-mock" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "types-aioboto3" }, + { name = "types-openpyxl" }, { name = "types-pyjwt" }, { name = "types-requests" }, { name = "watchdog" }, @@ -2787,6 +2802,7 @@ requires-dist = [ { name = "google-genai", specifier = ">=1.38.0" }, { name = "greenlet", specifier = ">=3.2.3" }, { name = "httpx", specifier = ">=0.28.1" }, + { name = "jinja2", specifier = ">=3.1.6" }, { name = "langchain", specifier = ">=1.0.2" }, { name = "langchain-anthropic", specifier = ">=1.0.0" }, { name = "langchain-google-genai", specifier = ">=3.0.0" }, @@ -2827,6 +2843,8 @@ dev = [ { name = "pytest-mock", specifier = ">=3.12.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "ruff", specifier = ">=0.14.3" }, + { name = "types-aioboto3", specifier = ">=15.5.0" }, + { name = "types-openpyxl", specifier = ">=3.1.5.20250919" }, { name = "types-pyjwt", specifier = ">=1.7.1" }, { name = "types-requests", specifier = ">=2.32.4.20250611" }, { name = "watchdog", specifier = ">=3.0.0" }, @@ -2991,6 +3009,41 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/64/7713ffe4b5983314e9d436a90d5bd4f63b6054e2aca783a3cfc44cb95bbf/typer-0.20.0-py3-none-any.whl", hash = "sha256:5b463df6793ec1dca6213a3cf4c0f03bc6e322ac5e16e13ddd622a889489784a", size = 47028, upload-time = "2025-10-20T17:03:47.617Z" }, ] +[[package]] +name = 
"types-aioboto3" +version = "15.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore-stubs" }, + { name = "types-aiobotocore" }, + { name = "types-s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/76/e162ea2ef8d414d4f36f28a6e0b6078ccef3f2f9d5f957859f303995c528/types_aioboto3-15.5.0.tar.gz", hash = "sha256:5769a1c3df7ca1abedf3656ddf0b970c9b0436d0f88cf4686040b55cd2a02925", size = 81059, upload-time = "2025-10-31T01:11:54.445Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/1d/e187fbe9771dffb5f0801e315ac23a6c383c14d1cbb90da6ca3ad1ea9b06/types_aioboto3-15.5.0-py3-none-any.whl", hash = "sha256:8aed7c9b6fe9b59e6ce74f7a6db7b8a9912a34c8f80ed639fac1fa59d6b20aa1", size = 42521, upload-time = "2025-10-31T01:11:47.832Z" }, +] + +[[package]] +name = "types-aiobotocore" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore-stubs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a4/5e/968935cba8e01f63b45179aab963551f6c2631d2df784de3c6ee04b279c6/types_aiobotocore-3.1.0.tar.gz", hash = "sha256:9c36d9d29044b424657900fa99e8c058f73d5a755e93d21e4bbeb0eea8f19392", size = 86416, upload-time = "2026-01-03T02:09:47.965Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/20858211d52028ff3f7704babf111a18c95b36697ce0d2063e3e8979fc8c/types_aiobotocore-3.1.0-py3-none-any.whl", hash = "sha256:4ab223580f4249a84ebac17461ff3719b541a1dc94fe0840d392e3f5d2ba7ef0", size = 54201, upload-time = "2026-01-03T02:09:41.068Z" }, +] + +[[package]] +name = "types-awscrt" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/1f/febd2df22e24f77b759db0dd9ecdd7f07f055e6a4dbbb699c5eb34b617ef/types_awscrt-0.30.0.tar.gz", hash = "sha256:362fd8f5eaebcfcd922cb9fd8274fb375df550319f78031ee3779eac0b9ecc79", size = 17761, upload-time = "2025-12-12T01:55:59.626Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/5f/15999051fca2949a67562c3f80fae2dd5d3404a3f97b326b614533843281/types_awscrt-0.30.0-py3-none-any.whl", hash = "sha256:8204126e01a00eaa4a746e7a0076538ca0e4e3f52408adec0ab9b471bb0bb64b", size = 42392, upload-time = "2025-12-12T01:55:58.194Z" }, +] + [[package]] name = "types-cryptography" version = "3.3.23.2" @@ -3000,6 +3053,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/36/92dfe7e5056694e78caefd05b383140c74c7fcbfc63d26ee514c77f2d8a2/types_cryptography-3.3.23.2-py3-none-any.whl", hash = "sha256:b965d548f148f8e87f353ccf2b7bd92719fdf6c845ff7cedf2abb393a0643e4f", size = 30223, upload-time = "2022-11-08T18:29:26.848Z" }, ] +[[package]] +name = "types-openpyxl" +version = "3.1.5.20250919" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/12/8bc4a25d49f1e4b7bbca868daa3ee80b1983d8137b4986867b5b65ab2ecd/types_openpyxl-3.1.5.20250919.tar.gz", hash = "sha256:232b5906773eebace1509b8994cdadda043f692cfdba9bfbb86ca921d54d32d7", size = 100880, upload-time = "2025-09-19T02:54:39.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/3c/d49cf3f4489a10e9ddefde18fd258f120754c5825d06d145d9a0aaac770b/types_openpyxl-3.1.5.20250919-py3-none-any.whl", hash = "sha256:bd06f18b12fd5e1c9f0b666ee6151d8140216afa7496f7ebb9fe9d33a1a3ce99", size = 166078, upload-time = "2025-09-19T02:54:38.657Z" }, +] + [[package]] name = "types-pyjwt" version = "1.7.1" @@ -3024,6 +3086,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/2a/20/9a227ea57c1285986c4cf78400d0a91615d25b24e257fd9e2969606bdfae/types_requests-2.32.4.20250913-py3-none-any.whl", hash = "sha256:78c9c1fffebbe0fa487a418e0fa5252017e9c60d1a2da394077f1780f655d7e1", size = 20658, upload-time = "2025-09-13T02:40:01.115Z" }, ] +[[package]] +name = "types-s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/64/42689150509eb3e6e82b33ee3d89045de1592488842ddf23c56957786d05/types_s3transfer-0.16.0.tar.gz", hash = "sha256:b4636472024c5e2b62278c5b759661efeb52a81851cde5f092f24100b1ecb443", size = 13557, upload-time = "2025-12-08T08:13:09.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/27/e88220fe6274eccd3bdf95d9382918716d312f6f6cef6a46332d1ee2feff/types_s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:1c0cd111ecf6e21437cb410f5cddb631bfb2263b77ad973e79b9c6d0cb24e0ef", size = 19247, upload-time = "2025-12-08T08:13:08.426Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From 28a16d39d7960eeba463f511bbbf149fec1aa87e Mon Sep 17 00:00:00 2001 From: xinquiry Date: Sun, 11 Jan 2026 04:56:42 +0800 Subject: [PATCH 3/4] feat: implement the deep research agent --- service/app/agents/__init__.py | 88 +- service/app/agents/base_graph_agent.py | 11 +- service/app/agents/components/__init__.py | 285 ++++ service/app/agents/components/base.py | 184 +++ service/app/agents/factory.py | 314 ++++- service/app/agents/graph_builder.py | 804 ++++++++++++ service/app/agents/react_agent.py | 139 -- service/app/agents/system/__init__.py | 219 ++++ service/app/agents/system/base.py | 209 +++ .../agents/system/deep_research/__init__.py | 54 + .../app/agents/system/deep_research/agent.py | 200 +++ .../agents/system/deep_research/components.py | 373 ++++++ .../system/deep_research/configuration.py | 53 + .../system/deep_research/graph_config.py | 380 ++++++ .../agents/system/deep_research/prompts.py | 343 +++++ .../app/agents/system/deep_research/state.py | 83 ++ .../app/agents/system/deep_research/utils.py | 201 +++ service/app/agents/system/react/__init__.py | 19 + service/app/agents/system/react/agent.py | 204 +++ service/app/agents/types.py | 154 +++ service/app/api/v1/agents.py | 78 +- service/app/api/v1/files.py | 26 +- service/app/api/ws/v1/chat.py | 2 +- service/app/core/chat/agent_event_handler.py | 417 ++++++ service/app/core/chat/history.py | 21 +- service/app/core/chat/langchain.py | 130 +- service/app/core/chat/stream_handlers.py | 203 ++- service/app/mcp/graph_tools.py | 1162 ----------------- service/app/repos/message.py | 1 + service/app/schemas/graph_config.py | 38 + service/app/tasks/chat.py | 27 +- .../components/AgentExecutionBubble.tsx | 295 +++++ .../components/AgentExecutionTimeline.tsx | 324 +++++ .../layouts/components/AgentNodeItem.tsx | 182 +++ .../layouts/components/AgentPhaseCard.tsx | 339 +++++ .../layouts/components/AgentPhaseItem.tsx | 172 +++ .../layouts/components/AgentProgressBar.tsx | 105 ++ .../layouts/components/ChatBubble.tsx | 39 + .../layouts/components/NodeExecutionCard.tsx | 293 +++++ web/src/components/modals/AddAgentModal.tsx | 483 +++++-- web/src/components/modals/EditAgentModal.tsx | 90 +- web/src/core/chat/index.ts | 1 + web/src/core/chat/messageProcessor.ts | 155 ++- web/src/core/chat/types.ts | 31 +- web/src/service/xyzenService.ts | 42 +- web/src/store/slices/agentSlice.ts | 79 +- web/src/store/slices/chatSlice.ts | 503 ++++++- web/src/store/types.ts | 72 + 
 web/src/types/agentEvents.ts | 238 ++++
 web/src/types/agents.ts | 34 +
 50 files changed, 8254 insertions(+), 1645 deletions(-)
 create mode 100644 service/app/agents/components/__init__.py
 create mode 100644 service/app/agents/components/base.py
 create mode 100644 service/app/agents/graph_builder.py
 delete mode 100644 service/app/agents/react_agent.py
 create mode 100644 service/app/agents/system/__init__.py
 create mode 100644 service/app/agents/system/base.py
 create mode 100644 service/app/agents/system/deep_research/__init__.py
 create mode 100644 service/app/agents/system/deep_research/agent.py
 create mode 100644 service/app/agents/system/deep_research/components.py
 create mode 100644 service/app/agents/system/deep_research/configuration.py
 create mode 100644 service/app/agents/system/deep_research/graph_config.py
 create mode 100644 service/app/agents/system/deep_research/prompts.py
 create mode 100644 service/app/agents/system/deep_research/state.py
 create mode 100644 service/app/agents/system/deep_research/utils.py
 create mode 100644 service/app/agents/system/react/__init__.py
 create mode 100644 service/app/agents/system/react/agent.py
 create mode 100644 service/app/agents/types.py
 create mode 100644 service/app/core/chat/agent_event_handler.py
 delete mode 100644 service/app/mcp/graph_tools.py
 create mode 100644 web/src/components/layouts/components/AgentExecutionBubble.tsx
 create mode 100644 web/src/components/layouts/components/AgentExecutionTimeline.tsx
 create mode 100644 web/src/components/layouts/components/AgentNodeItem.tsx
 create mode 100644 web/src/components/layouts/components/AgentPhaseCard.tsx
 create mode 100644 web/src/components/layouts/components/AgentPhaseItem.tsx
 create mode 100644 web/src/components/layouts/components/AgentProgressBar.tsx
 create mode 100644 web/src/components/layouts/components/NodeExecutionCard.tsx
 create mode 100644 web/src/types/agentEvents.ts

diff --git a/service/app/agents/__init__.py b/service/app/agents/__init__.py
index 591650c2..25f6db21 100644
--- a/service/app/agents/__init__.py
+++ b/service/app/agents/__init__.py
@@ -11,7 +11,6 @@
 from pathlib import Path
 from typing import Any
 
-from sqlmodel import select
 from sqlmodel.ext.asyncio.session import AsyncSession
 
 from .base_graph_agent import BaseBuiltinGraphAgent
@@ -34,7 +33,15 @@ def __init__(self) -> None:
     def _discover_agents(self) -> None:
         """Automatically discover all builtin graph agents in the current directory."""
         current_dir = Path(__file__).parent
-        python_files = [f for f in current_dir.glob("*.py") if f.name not in ["__init__.py", "base_graph_agent.py"]]
+        excluded_files = {
+            "__init__.py",
+            "base_graph_agent.py",
+            "factory.py",
+            "graph_builder.py",
+            "state.py",
+            "types.py",
+        }
+        python_files = [f for f in current_dir.glob("*.py") if f.name not in excluded_files]
 
         for file_path in python_files:
             module_name = file_path.stem
@@ -234,83 +240,23 @@ def get_registry_stats(self) -> dict[str, Any]:
 
     async def seed_to_database(self, db: AsyncSession) -> dict[str, Any]:
         """
-        Seed all builtin graph agents to the database with is_official=True.
+        Seed all builtin graph agents to the database.
 
-        This method creates or updates GraphAgent records in the database for all
-        registered builtin agents. It ensures that builtin agents are accessible
-        through the standard agent API alongside user-created agents.
+        NOTE: This method is deprecated. Builtin agents are now registered
+        through the system agent registry and don't need database seeding.
+ The method is kept for backwards compatibility but returns empty stats. Args: - db: Async database session + db: Async database session (unused) Returns: - dict: Statistics about the seeding operation - - created: Number of new agents created - - updated: Number of existing agents updated - - failed: Number of agents that failed to sync - - total: Total number of agents processed + dict: Empty statistics (method is a no-op) """ - from app.models.graph import GraphAgent, GraphAgentCreate - from app.repos.graph import GraphRepository - - repo = GraphRepository(db) - stats = {"created": 0, "updated": 0, "failed": 0, "total": len(self.agents)} - - for agent_name, agent_config in self.agents.items(): - try: - agent_instance = agent_config["agent"] - metadata = agent_config["metadata"] - - # Check if agent already exists by name and is_official=True - existing_query = select(GraphAgent).where( - GraphAgent.name == metadata.get("name", agent_name), - GraphAgent.is_official == True, # noqa: E712 - ) - result = await db.exec(existing_query) - existing_agent = result.first() - - # Prepare agent data - agent_data = { - "name": metadata.get("name", agent_name), - "description": metadata.get("description", ""), - "state_schema": agent_instance.get_state_schema(), - "is_active": True, - "is_published": True, # Official agents are published by default - "is_official": True, - "parent_agent_id": None, - } - - if existing_agent: - # Update existing agent - for key, value in agent_data.items(): - if key not in ["name"]: # Don't update name - setattr(existing_agent, key, value) - db.add(existing_agent) - await db.flush() - await db.refresh(existing_agent) - stats["updated"] += 1 - logger.info(f"Updated official graph agent in database: {agent_name}") - else: - # Create new agent with system user_id - agent_create = GraphAgentCreate(**agent_data) - new_agent = await repo.create_graph_agent(agent_create, user_id="system") - await db.flush() - await db.refresh(new_agent) - stats["created"] += 1 - logger.info(f"Created official graph agent in database: {agent_name}") - - except Exception as e: - stats["failed"] += 1 - logger.error(f"Failed to seed builtin agent '{agent_name}' to database: {e}") - # Continue with other agents even if one fails - logger.info( - f"Builtin agent database seeding completed: " - f"{stats['created']} created, {stats['updated']} updated, " - f"{stats['failed']} failed out of {stats['total']} total" + "seed_to_database is deprecated. Builtin agents are registered " + "through the system agent registry. No database seeding required." 
) - - return stats + return {"created": 0, "updated": 0, "failed": 0, "total": 0} # Create global registry instance diff --git a/service/app/agents/base_graph_agent.py b/service/app/agents/base_graph_agent.py index 1c81185d..15010894 100644 --- a/service/app/agents/base_graph_agent.py +++ b/service/app/agents/base_graph_agent.py @@ -9,10 +9,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import Any, TypeVar -if TYPE_CHECKING: - from langgraph.graph.state import CompiledStateGraph +from langgraph.graph.state import CompiledStateGraph +from pydantic import BaseModel + +# TypeVar for state types in subclasses +StateT = TypeVar("StateT", bound=BaseModel) class AgentMetadata: @@ -107,7 +110,7 @@ def __init__( self.license_ = license_ @abstractmethod - def build_graph(self) -> "CompiledStateGraph": + def build_graph(self) -> CompiledStateGraph[Any, None, Any, Any]: """ Build and return the LangGraph StateGraph for this agent. diff --git a/service/app/agents/components/__init__.py b/service/app/agents/components/__init__.py new file mode 100644 index 00000000..a9697040 --- /dev/null +++ b/service/app/agents/components/__init__.py @@ -0,0 +1,285 @@ +""" +Component Registry - Central registry for reusable agent components. + +This module provides the ComponentRegistry class that manages registration, +discovery, and retrieval of reusable components from system agents. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from .base import ( + BaseComponent, + ComponentMetadata, + ComponentType, + NodeComponent, + PromptTemplateComponent, + StateSchemaComponent, + SubgraphComponent, +) + +logger = logging.getLogger(__name__) + + +class ComponentRegistry: + """ + Registry for reusable agent components. + + This registry allows: + - Registration of components from system agents + - Discovery of components by type, tag, or key + - Export of component configurations for use in user agents + """ + + def __init__(self) -> None: + self._components: dict[str, BaseComponent] = {} + self._by_type: dict[ComponentType, list[str]] = {t: [] for t in ComponentType} + self._by_tag: dict[str, list[str]] = {} + + def register(self, component: BaseComponent, override: bool = False) -> None: + """ + Register a component in the registry. + + Args: + component: The component to register + override: If True, allow overwriting existing components + + Raises: + ValueError: If component key already exists and override is False + """ + key = component.metadata.key + + if key in self._components and not override: + raise ValueError(f"Component '{key}' already registered. Use override=True to replace.") + + # Validate component + errors = component.validate() + if errors: + logger.warning(f"Component '{key}' has validation warnings: {errors}") + + # Register component + self._components[key] = component + + # Index by type + comp_type = component.metadata.component_type + if key not in self._by_type[comp_type]: + self._by_type[comp_type].append(key) + + # Index by tags + for tag in component.metadata.tags: + if tag not in self._by_tag: + self._by_tag[tag] = [] + if key not in self._by_tag[tag]: + self._by_tag[tag].append(key) + + logger.info(f"Registered component: {key} ({comp_type})") + + def unregister(self, key: str) -> bool: + """ + Remove a component from the registry. 
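+
+        The type and tag indexes are updated alongside the removal.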
+ + Args: + key: Component key to remove + + Returns: + True if component was removed, False if not found + """ + if key not in self._components: + return False + + component = self._components[key] + comp_type = component.metadata.component_type + + # Remove from indexes + if key in self._by_type[comp_type]: + self._by_type[comp_type].remove(key) + + for tag in component.metadata.tags: + if tag in self._by_tag and key in self._by_tag[tag]: + self._by_tag[tag].remove(key) + + # Remove from main registry + del self._components[key] + logger.info(f"Unregistered component: {key}") + return True + + def get(self, key: str) -> BaseComponent | None: + """ + Get a component by its key. + + Args: + key: Component key (e.g., 'system:deep_research:query_analyzer') + + Returns: + The component or None if not found + """ + return self._components.get(key) + + def get_metadata(self, key: str) -> ComponentMetadata | None: + """ + Get metadata for a component. + + Args: + key: Component key + + Returns: + Component metadata or None if not found + """ + component = self._components.get(key) + return component.metadata if component else None + + def get_config(self, key: str) -> dict[str, Any] | None: + """ + Get the exported configuration for a component. + + Args: + key: Component key + + Returns: + Exported configuration dict or None if not found + """ + component = self._components.get(key) + return component.export_config() if component else None + + def list_all(self) -> list[str]: + """List all registered component keys.""" + return list(self._components.keys()) + + def list_metadata(self) -> list[ComponentMetadata]: + """Get metadata for all registered components.""" + return [comp.metadata for comp in self._components.values()] + + def list_by_type(self, component_type: ComponentType) -> list[ComponentMetadata]: + """ + List all components of a specific type. + + Args: + component_type: Type to filter by + + Returns: + List of component metadata + """ + keys = self._by_type.get(component_type, []) + return [self._components[k].metadata for k in keys if k in self._components] + + def list_by_tag(self, tag: str) -> list[ComponentMetadata]: + """ + List all components with a specific tag. + + Args: + tag: Tag to filter by + + Returns: + List of component metadata + """ + keys = self._by_tag.get(tag, []) + return [self._components[k].metadata for k in keys if k in self._components] + + def search( + self, + query: str | None = None, + component_type: ComponentType | None = None, + tags: list[str] | None = None, + ) -> list[ComponentMetadata]: + """ + Search for components matching criteria. + + Args: + query: Text to search in name/description + component_type: Filter by type + tags: Filter by tags (any match) + + Returns: + List of matching component metadata + """ + results: list[ComponentMetadata] = [] + + for component in self._components.values(): + metadata = component.metadata + + # Filter by type + if component_type and metadata.component_type != component_type: + continue + + # Filter by tags + if tags and not any(t in metadata.tags for t in tags): + continue + + # Filter by query + if query: + query_lower = query.lower() + if ( + query_lower not in metadata.name.lower() + and query_lower not in metadata.description.lower() + and query_lower not in metadata.key.lower() + ): + continue + + results.append(metadata) + + return results + + def export_all(self) -> dict[str, dict[str, Any]]: + """ + Export all components as JSON-serializable configs. 
+ + Returns: + Dictionary mapping keys to exported configs + """ + return {key: comp.export_config() for key, comp in self._components.items()} + + def get_stats(self) -> dict[str, Any]: + """ + Get statistics about the registry. + + Returns: + Dictionary with registry statistics + """ + return { + "total_components": len(self._components), + "by_type": {t.value: len(keys) for t, keys in self._by_type.items()}, + "unique_tags": list(self._by_tag.keys()), + "tag_counts": {tag: len(keys) for tag, keys in self._by_tag.items()}, + } + + +# Global registry instance +component_registry = ComponentRegistry() + + +# Convenience functions +def register_component(component: BaseComponent, override: bool = False) -> None: + """Register a component in the global registry.""" + component_registry.register(component, override) + + +def get_component(key: str) -> BaseComponent | None: + """Get a component from the global registry.""" + return component_registry.get(key) + + +def get_component_config(key: str) -> dict[str, Any] | None: + """Get a component's exported config from the global registry.""" + return component_registry.get_config(key) + + +# Export +__all__ = [ + # Registry + "ComponentRegistry", + "component_registry", + # Convenience functions + "register_component", + "get_component", + "get_component_config", + # Base classes (re-exported from base.py) + "BaseComponent", + "NodeComponent", + "SubgraphComponent", + "StateSchemaComponent", + "PromptTemplateComponent", + "ComponentType", + "ComponentMetadata", +] diff --git a/service/app/agents/components/base.py b/service/app/agents/components/base.py new file mode 100644 index 00000000..dfa7d10a --- /dev/null +++ b/service/app/agents/components/base.py @@ -0,0 +1,184 @@ +""" +Base Component Module - Abstract base classes for reusable agent components. + +This module defines the component abstraction that allows system agents +to export reusable pieces (nodes, subgraphs, state schemas, prompts) +that can be composed into user-defined agents. 
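+
+A minimal sketch of a concrete component (illustrative only; the class and
+key below are hypothetical, not shipped by this module):
+
+    class GreetingPrompt(PromptTemplateComponent):
+        @property
+        def metadata(self) -> ComponentMetadata:
+            return ComponentMetadata(
+                key="system:example:greeting",
+                name="Greeting Prompt",
+                description="Renders a one-line greeting",
+                component_type=ComponentType.PROMPT_TEMPLATE,
+            )
+
+        def export_config(self) -> dict[str, Any]:
+            return {"template": self.get_template(), "variables": self.get_variables()}
+
+        def get_template(self) -> str:
+            return "Hello, {{ name }}!"
+
+        def get_variables(self) -> list[str]:
+            return ["name"]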
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + + +class ComponentType(StrEnum): + """Types of reusable components.""" + + NODE = "node" # Reusable graph node + SUBGRAPH = "subgraph" # Complete subgraph that can be embedded + STATE_SCHEMA = "state_schema" # Reusable state field definitions + PROMPT_TEMPLATE = "prompt_template" # Reusable prompt template + REDUCER = "reducer" # Custom state reducer function + + +class ComponentMetadata(BaseModel): + """Metadata describing a registered component.""" + + key: str = Field( + description="Unique identifier: 'namespace:agent:component' (e.g., 'system:deep_research:query_analyzer')" + ) + name: str = Field(description="Human-readable name") + description: str = Field(description="Detailed description of what this component does") + component_type: ComponentType = Field(description="Type of component") + version: str = Field(default="1.0.0", description="Semantic version") + author: str = Field(default="Xyzen", description="Component author") + tags: list[str] = Field(default_factory=list, description="Tags for discovery") + + # Schema information + input_schema: dict[str, Any] | None = Field(default=None, description="JSON Schema for component inputs") + output_schema: dict[str, Any] | None = Field(default=None, description="JSON Schema for component outputs") + + # Dependencies + required_tools: list[str] = Field(default_factory=list, description="Tools required by this component") + required_components: list[str] = Field(default_factory=list, description="Other components this depends on") + + +class BaseComponent(ABC): + """ + Abstract base class for all reusable components. + + Components are building blocks that can be shared across agents. + Each component provides: + - Metadata for discovery and documentation + - Export functionality for JSON serialization + - Validation to ensure correct configuration + """ + + @property + @abstractmethod + def metadata(self) -> ComponentMetadata: + """Return component metadata.""" + ... + + @abstractmethod + def export_config(self) -> dict[str, Any]: + """ + Export component as JSON-serializable configuration. + + The exported config can be used in GraphConfig to instantiate + this component in a user-defined agent. + + Returns: + Dictionary containing the component's configuration + """ + ... + + def validate(self) -> list[str]: + """ + Validate component configuration. + + Override this method to add custom validation logic. + + Returns: + List of validation errors (empty if valid) + """ + return [] + + def get_example_usage(self) -> str | None: + """ + Return an example of how to use this component. + + Override to provide usage examples for documentation. + + Returns: + Example usage string or None + """ + return None + + +class NodeComponent(BaseComponent): + """Base class for reusable node components.""" + + @property + def component_type(self) -> ComponentType: + return ComponentType.NODE + + +class SubgraphComponent(BaseComponent): + """Base class for reusable subgraph components.""" + + @property + def component_type(self) -> ComponentType: + return ComponentType.SUBGRAPH + + @abstractmethod + def get_nodes(self) -> list[dict[str, Any]]: + """Return the list of node configurations for this subgraph.""" + ... + + @abstractmethod + def get_edges(self) -> list[dict[str, Any]]: + """Return the list of edge configurations for this subgraph.""" + ... 
+ + @abstractmethod + def get_entry_point(self) -> str: + """Return the entry point node ID for this subgraph.""" + ... + + @abstractmethod + def get_exit_points(self) -> list[str]: + """Return the exit point node IDs for this subgraph.""" + ... + + +class StateSchemaComponent(BaseComponent): + """Base class for reusable state schema components.""" + + @property + def component_type(self) -> ComponentType: + return ComponentType.STATE_SCHEMA + + @abstractmethod + def get_fields(self) -> dict[str, dict[str, Any]]: + """ + Return state field definitions. + + Returns: + Dictionary mapping field names to StateFieldSchema configs + """ + ... + + +class PromptTemplateComponent(BaseComponent): + """Base class for reusable prompt template components.""" + + @property + def component_type(self) -> ComponentType: + return ComponentType.PROMPT_TEMPLATE + + @abstractmethod + def get_template(self) -> str: + """Return the Jinja2 template string.""" + ... + + @abstractmethod + def get_variables(self) -> list[str]: + """Return list of variables expected by the template.""" + ... + + +# Export +__all__ = [ + "ComponentType", + "ComponentMetadata", + "BaseComponent", + "NodeComponent", + "SubgraphComponent", + "StateSchemaComponent", + "PromptTemplateComponent", +] diff --git a/service/app/agents/factory.py b/service/app/agents/factory.py index 005c83c3..e53941e3 100644 --- a/service/app/agents/factory.py +++ b/service/app/agents/factory.py @@ -3,16 +3,26 @@ This module provides factory functions to instantiate the appropriate agent based on session configuration, agent type, and other parameters. + +Supports: +- graph_config: JSON-configured graph agent +- No config: Falls back to the built-in react system agent +- graph_config with metadata.system_agent_key: Uses specified system agent + +The default agent is the "react" system agent. """ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from langgraph.graph.state import CompiledStateGraph from sqlmodel.ext.asyncio.session import AsyncSession +from app.agents.types import DynamicCompiledGraph, LLMFactory, SystemAgentInfo +from app.core.chat.agent_event_handler import AgentEventContext + if TYPE_CHECKING: from uuid import UUID @@ -26,16 +36,19 @@ logger = logging.getLogger(__name__) +# Default system agent key when no agent is specified +DEFAULT_SYSTEM_AGENT = "react" + async def create_chat_agent( db: AsyncSession, - agent_config: Agent | None, - topic: TopicModel, - user_provider_manager: ProviderManager, + agent_config: "Agent | None", + topic: "TopicModel", + user_provider_manager: "ProviderManager", provider_id: str | None, model_name: str | None, system_prompt: str, -) -> CompiledStateGraph: +) -> tuple[CompiledStateGraph[Any, None, Any, Any], AgentEventContext]: """ Create the appropriate agent for a chat session. 
@@ -52,88 +65,203 @@ async def create_chat_agent( system_prompt: System prompt for the agent Returns: - Compiled StateGraph ready for streaming execution - - Future Enhancement: - Support for different agent types based on agent_config.agent_type: - - "react" (default): Standard ReAct agent - - "plan_and_execute": Planning agent - - "custom_graph": DB-defined graph agent + Tuple of (CompiledStateGraph, AgentEventContext) for streaming execution """ - from app.agents.react_agent import ReActAgent from app.core.chat.langchain_tools import prepare_langchain_tools from app.repos.session import SessionRepository # Get session for configuration session_repo = SessionRepository(db) - session: Session | None = await session_repo.get_session_by_id(topic.session_id) + session: "Session | None" = await session_repo.get_session_by_id(topic.session_id) # Check if built-in search is enabled google_search_enabled: bool = session.google_search_enabled if session else False - # Create LangChain model WITH provider-side web search binding. - # This ensures OpenAI gets `web_search_preview` and Gemini/Vertex gets `google_search`. - llm: BaseChatModel = await user_provider_manager.create_langchain_model( - provider_id, - model=model_name, - google_search_enabled=google_search_enabled, + # Prepare tools from MCP servers + session_id: "UUID | None" = topic.session_id if topic else None + tools: list["BaseTool"] = await prepare_langchain_tools(db, agent_config, session_id) + + # Determine how to execute this agent + agent_type_str, system_key = _resolve_agent_config(agent_config) + + # For frontend event tracking, use the actual system key (react, deep_research) + # instead of generic "system" so the UI can distinguish between agent types + event_agent_type = system_key if system_key else agent_type_str + + # Create event context for tracking + event_ctx = AgentEventContext( + agent_id=str(agent_config.id) if agent_config else "default", + agent_name=agent_config.name if agent_config else "Default Agent", + agent_type=event_agent_type, ) - # Prepare tools from MCP servers - session_id: UUID | None = topic.session_id if topic else None - tools: list[BaseTool] = await prepare_langchain_tools(db, agent_config, session_id) - - # Determine agent type (future: read from agent_config.agent_type) - agent_type: str = _get_agent_type(agent_config) - - # Create the appropriate agent - if agent_type == "react": - agent = ReActAgent( - llm=llm, - tools=tools, - system_prompt=system_prompt, - google_search_enabled=google_search_enabled, + # Create LLM factory for graph and system agents + async def create_llm(**kwargs: Any) -> "BaseChatModel": + override_model = kwargs.get("model") or model_name + override_temp = kwargs.get("temperature") + + # Build kwargs conditionally to avoid passing None values + # (some providers like Google don't accept temperature=None) + model_kwargs: dict[str, Any] = { + "model": override_model, + "google_search_enabled": google_search_enabled, + } + if override_temp is not None: + model_kwargs["temperature"] = override_temp + + return await user_provider_manager.create_langchain_model( + provider_id, + **model_kwargs, ) - compiled_graph: CompiledStateGraph = agent.build_graph() - logger.info(f"Created ReAct agent with {len(tools)} tools, google_search={google_search_enabled}") - return compiled_graph - # Future: Add more agent types here - # elif agent_type == "plan_and_execute": - # return _create_plan_and_execute_agent(...) 
- # elif agent_type == "custom_graph": - # return await _create_custom_graph_agent(db, agent_config, ...) + # Route to appropriate agent builder based on type + if agent_type_str == "graph": + return await _create_graph_agent( + agent_config, + create_llm, + tools, + event_ctx, + ) + + # System agent (includes react, deep_research, etc.) + return await _create_system_agent( + system_key, + agent_config, + create_llm, + tools, + system_prompt, + google_search_enabled, + event_ctx, + ) + + +async def _create_graph_agent( + agent_config: "Agent | None", + llm_factory: LLMFactory, + tools: list["BaseTool"], + event_ctx: AgentEventContext, +) -> tuple[DynamicCompiledGraph, AgentEventContext]: + """Create a JSON-configured graph agent.""" + from app.agents.graph_builder import GraphBuilder + from app.schemas.graph_config import GraphConfig + + if not agent_config or not agent_config.graph_config: + raise ValueError("Graph agent requires agent_config with graph_config") + + # Parse and validate graph config + graph_config = GraphConfig.model_validate(agent_config.graph_config) + + # Build tool registry + tool_registry = {t.name: t for t in tools} + + # Create graph builder + builder = GraphBuilder( + config=graph_config, + llm_factory=llm_factory, + tool_registry=tool_registry, + ) + + # Build graph + compiled_graph = builder.build() + + logger.info(f"Created graph agent '{agent_config.name}' with {len(graph_config.nodes)} nodes") + return compiled_graph, event_ctx + + +async def _create_system_agent( + system_key: str, + agent_config: "Agent | None", + llm_factory: LLMFactory, + tools: list["BaseTool"], + system_prompt: str, + google_search_enabled: bool, + event_ctx: AgentEventContext, +) -> tuple[CompiledStateGraph[Any, None, Any, Any], AgentEventContext]: + """ + Create a Python-coded system agent. + + Args: + system_key: System agent key (e.g., "react", "deep_research") + agent_config: Optional agent configuration + llm_factory: Factory function to create LLM + tools: List of tools available to the agent + system_prompt: System prompt for the agent + google_search_enabled: Whether Google search is enabled + event_ctx: Event context for tracking + + Returns: + Tuple of (CompiledStateGraph, AgentEventContext) + """ + from app.agents.system import system_agent_registry + + # Get system agent instance + llm = await llm_factory() + system_agent = system_agent_registry.get_instance( + system_key, + llm=llm, + tools=tools, + ) + + if not system_agent: + raise ValueError(f"System agent not found: {system_key}") + + # Special handling for react agent - pass system_prompt and google_search + if system_key == "react": + from app.agents.system.react import ReActAgent + + if isinstance(system_agent, ReActAgent): + system_agent.system_prompt = system_prompt + system_agent.google_search_enabled = google_search_enabled - # Default fallback to ReAct - logger.warning(f"Unknown agent type '{agent_type}', falling back to ReAct") - agent = ReActAgent(llm=llm, tools=tools, system_prompt=system_prompt) - return agent.build_graph() + # Build graph + compiled_graph = system_agent.build_graph() + logger.info(f"Created system agent '{system_key}' with {len(tools)} tools") + return compiled_graph, event_ctx -def _get_agent_type(agent_config: Agent | None) -> str: + +def _resolve_agent_config(agent_config: "Agent | None") -> tuple[str, str]: """ - Determine the agent type from configuration. + Resolve how to execute an agent based on its graph_config. + + Resolution logic: + 1. No agent_config → use react fallback + 2. 
Agent has graph_config with metadata.system_agent_key → use that system agent + 3. Agent has graph_config → use graph agent + 4. Agent has no graph_config → use react fallback Args: - agent_config: Agent configuration + agent_config: Agent configuration (may be None) Returns: - Agent type string (default: "react") + Tuple of (agent_type_for_events, system_key) + - agent_type_for_events: "graph" or "system" for event tracking + - system_key: System agent key (e.g., "react", "deep_research"), empty for pure graph """ if agent_config is None: - return "react" + # Default to react system agent + return "system", DEFAULT_SYSTEM_AGENT + + # Check graph_config + if agent_config.graph_config: + # Check for system_agent_key in metadata (uses system agent as base) + metadata = agent_config.graph_config.get("metadata", {}) + system_key = metadata.get("system_agent_key") + if system_key: + return "system", system_key + # Pure graph agent + return "graph", "" - # Future: Read from agent_config.agent_type field - # For now, always use react agent - return "react" + # No graph_config = use react fallback + return "system", DEFAULT_SYSTEM_AGENT async def create_agent_from_builtin( builtin_name: str, - user_provider_manager: ProviderManager, + user_provider_manager: "ProviderManager", provider_id: str | None, model_name: str | None, -) -> CompiledStateGraph | None: +) -> CompiledStateGraph[Any, None, Any, Any] | None: """ Create an agent from the builtin registry. @@ -154,9 +282,81 @@ async def create_agent_from_builtin( return None try: - graph: CompiledStateGraph = agent.build_graph() + graph: CompiledStateGraph[Any, None, Any, Any] = agent.build_graph() logger.info(f"Created builtin agent '{builtin_name}'") return graph except Exception as e: logger.error(f"Failed to build builtin agent '{builtin_name}': {e}") return None + + +async def create_system_agent_graph( + system_key: str, + user_provider_manager: "ProviderManager", + provider_id: str | None, + model_name: str | None, + tools: list["BaseTool"] | None = None, +) -> tuple[CompiledStateGraph[Any, None, Any, Any], AgentEventContext] | None: + """ + Create a system agent graph directly by key. + + Useful for invoking system agents outside of the normal chat flow. 
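+
+    Illustrative call ("manager" stands in for a real ProviderManager instance):
+
+        result = await create_system_agent_graph(
+            "deep_research", manager, provider_id=None, model_name=None
+        )
+        if result:
+            graph, event_ctx = result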
+ + Args: + system_key: System agent key (e.g., "deep_research") + user_provider_manager: Provider manager for LLM access + provider_id: Provider ID to use + model_name: Model name to use + tools: Optional tools to provide + + Returns: + Tuple of (CompiledStateGraph, AgentEventContext) or None if not found + """ + from app.agents.system import system_agent_registry + + # Create LLM + llm = await user_provider_manager.create_langchain_model( + provider_id, + model=model_name, + ) + + # Get system agent + system_agent = system_agent_registry.get_instance( + system_key, + llm=llm, + tools=tools or [], + ) + + if not system_agent: + logger.warning(f"System agent '{system_key}' not found") + return None + + # Create event context + # Use the actual system_key (e.g., "deep_research") as agent_type + # so frontend can distinguish between different system agents + event_ctx = AgentEventContext( + agent_id=system_key, + agent_name=system_agent.name, + agent_type=system_key, + ) + + # Build graph + try: + graph = system_agent.build_graph() + logger.info(f"Created system agent graph '{system_key}'") + return graph, event_ctx + except Exception as e: + logger.error(f"Failed to build system agent '{system_key}': {e}") + return None + + +def list_available_system_agents() -> list[SystemAgentInfo]: + """ + List all available system agents. + + Returns: + List of system agent metadata dictionaries + """ + from app.agents.system import system_agent_registry + + return system_agent_registry.get_all_metadata() # type: ignore[return-value] diff --git a/service/app/agents/graph_builder.py b/service/app/agents/graph_builder.py new file mode 100644 index 00000000..fd0b59d2 --- /dev/null +++ b/service/app/agents/graph_builder.py @@ -0,0 +1,804 @@ +""" +Graph Builder - Compiles JSON GraphConfig into LangGraph CompiledStateGraph. + +This module provides the core functionality to transform JSON-based agent +configurations into executable LangGraph workflows. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any, Callable, Hashable + +from jinja2 import Template +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langgraph.graph import END, START, StateGraph +from pydantic import BaseModel, ConfigDict, Field, create_model + +from app.agents.types import ( + DynamicCompiledGraph, + DynamicStateGraph, + LLMFactory, + NodeFunction, + RouterFunction, + StateDict, +) +from app.schemas.graph_config import ( + EdgeCondition, + GraphConfig, + GraphEdgeConfig, + GraphNodeConfig, + NodeType, + ReducerType, + StructuredOutputSchema, + validate_graph_config, +) + +if TYPE_CHECKING: + from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +def _extract_content_str(content: str | list[str | dict[str, Any]] | Any) -> str: + """Extract string content from LLM response content field. + + Note: LangChain's content type includes Unknown due to incomplete stubs. 
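+
+    Illustrative behavior (follows directly from the branches below):
+
+        _extract_content_str("hi")                                  -> "hi"
+        _extract_content_str(["a", {"type": "text", "text": "b"}])  -> "ab"
+        _extract_content_str(42)                                    -> "42"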
+    """
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        # Handle list of content blocks (e.g., multimodal responses)
+        parts: list[str] = []
+        for item in content:
+            if isinstance(item, str):
+                parts.append(item)
+            elif isinstance(item, dict) and "text" in item:
+                parts.append(str(item["text"]))
+        return "".join(parts)
+    return str(content)
+
+
+# --- State Reducers ---
+
+
+def append_reducer(existing: list[Any] | None, new: Any) -> list[Any]:
+    """Append new value(s) to existing list."""
+    if existing is None:
+        existing = []
+    if isinstance(new, list):
+        return existing + new
+    return existing + [new]
+
+
+def merge_reducer(existing: dict[str, Any] | None, new: dict[str, Any] | None) -> dict[str, Any]:
+    """Merge new dict into existing dict."""
+    if existing is None:
+        existing = {}
+    if new is None:
+        return existing
+    return {**existing, **new}
+
+
+def messages_reducer(
+    existing: list[BaseMessage] | None, new: list[BaseMessage] | BaseMessage | None
+) -> list[BaseMessage]:
+    """Reducer for message lists: appends new message(s) to the existing list."""
+    if existing is None:
+        existing = []
+    if new is None:
+        return existing
+    if isinstance(new, list):
+        return existing + new
+    return existing + [new]
+
+
+REDUCERS: dict[ReducerType, Callable[..., Any]] = {
+    ReducerType.APPEND: append_reducer,
+    ReducerType.MERGE: merge_reducer,
+    ReducerType.MESSAGES: messages_reducer,
+}
+
+
+# --- Dynamic State Builder ---
+
+
+def build_state_class(config: GraphConfig) -> type[BaseModel]:
+    """
+    Dynamically create a Pydantic state class from GraphConfig.
+
+    The state class will have:
+    - All fields defined in state_schema
+    - Built-in 'messages' field (list[BaseMessage])
+    - Built-in 'execution_context' field (dict)
+    """
+    fields: dict[str, tuple[Any, Any]] = {}
+
+    # Built-in fields (always present)
+    fields["messages"] = (list[BaseMessage], Field(default_factory=list))
+    fields["execution_context"] = (dict[str, Any], Field(default_factory=dict))
+
+    # Type mapping for schema fields
+    type_map: dict[str, Any] = {
+        "string": str,
+        "str": str,
+        "int": int,
+        "float": float,
+        "bool": bool,
+        "list": list[Any],
+        "dict": dict[str, Any],
+        "any": Any,
+        "messages": list[BaseMessage],
+    }
+
+    for field_name, field_schema in config.state_schema.fields.items():
+        python_type = type_map.get(field_schema.type, Any)
+        default_value = field_schema.default
+        fields[field_name] = (python_type | None, Field(default=default_value))
+
+    # Create dynamic model with arbitrary types allowed
+    DynamicState: type[BaseModel] = create_model(
+        "DynamicGraphState",
+        __config__=ConfigDict(arbitrary_types_allowed=True),  # type: ignore[call-overload]
+        **fields,  # type: ignore[arg-type]
+    )
+    return DynamicState
+
+
+# --- Graph Builder ---
+
+
+class GraphBuilder:
+    """
+    Builds LangGraph from JSON GraphConfig.
+
+    This class compiles a JSON-based agent configuration into an executable
+    LangGraph workflow with proper state management, node execution, and routing.
+    """
+
+    config: GraphConfig
+    llm_factory: LLMFactory
+    tool_registry: dict[str, "BaseTool"]
+    context: dict[str, Any]
+    state_class: type[BaseModel]
+    _template_cache: dict[str, Template]
+
+    def __init__(
+        self,
+        config: GraphConfig,
+        llm_factory: LLMFactory,
+        tool_registry: dict[str, "BaseTool"],
+        context: dict[str, Any] | None = None,
+    ) -> None:
+        """
+        Initialize the GraphBuilder.
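+
+        Validation runs eagerly: an invalid config raises ValueError here,
+        before any graph is compiled.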
+ + Args: + config: GraphConfig defining the agent workflow + llm_factory: Factory function to create LLM instances + tool_registry: Dictionary mapping tool names to BaseTool instances + context: Optional runtime context passed to templates + """ + self.config = config + self.llm_factory = llm_factory + self.tool_registry = tool_registry + self.context = context or {} + + # Validate configuration + errors = validate_graph_config(config) + if errors: + raise ValueError(f"Invalid graph configuration: {errors}") + + # Build dynamic state class + self.state_class = build_state_class(config) + + # Cache for compiled templates + self._template_cache = {} + + def build(self) -> DynamicCompiledGraph: + """ + Build and compile the LangGraph from configuration. + + Returns: + CompiledStateGraph ready for execution + """ + logger.info(f"Building graph with {len(self.config.nodes)} nodes") + + # Create graph with dynamic state + graph: DynamicStateGraph = StateGraph(self.state_class) + + # Add all nodes + for node_config in self.config.nodes: + node_fn = self._build_node(node_config) + graph.add_node(node_config.id, node_fn) # type: ignore[arg-type] + logger.debug(f"Added node: {node_config.id} ({node_config.type})") + + # Add edges + self._add_edges(graph) + + # Compile and return + compiled: DynamicCompiledGraph = graph.compile() + logger.info("Graph compiled successfully") + return compiled + + def _get_template(self, template_str: str) -> Template: + """Get or create a cached Jinja2 template.""" + if template_str not in self._template_cache: + self._template_cache[template_str] = Template(template_str) + return self._template_cache[template_str] + + def _state_to_dict(self, state: StateDict | BaseModel) -> dict[str, Any]: + """Convert state to dict, handling both dict and Pydantic model inputs.""" + if isinstance(state, BaseModel): + return state.model_dump() + return dict(state) if state else {} + + def _format_messages_for_prompt(self, messages: list[BaseMessage]) -> str: + """Format a list of messages into a string for prompt templates.""" + if not messages: + return "" + + formatted_parts: list[str] = [] + for msg in messages: + role = msg.__class__.__name__.replace("Message", "") # HumanMessage -> Human + content = msg.content if hasattr(msg, "content") else str(msg) + if isinstance(content, list): + # Handle multimodal content + content = " ".join(str(c.get("text", c)) if isinstance(c, dict) else str(c) for c in content) + formatted_parts.append(f"{role}: {content}") + + return "\n".join(formatted_parts) + + def _render_template(self, template_str: str, state: StateDict | BaseModel) -> str: + """Render a template with state and context. + + Supports both Jinja2 syntax ({{ variable }}) and Python format strings ({variable}). + This allows compatibility with existing prompts that use {messages} style placeholders. 
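+
+        Illustrative pass ordering (assuming state has a "topic" field):
+            template: "Summarize {{ state.topic }} as of {date}"
+            pass 1 (Jinja2) fills {{ state.topic }}; pass 2 substitutes {date}
+            with today's date and leaves unknown {placeholders} untouched.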
+        """
+        import datetime
+
+        template = self._get_template(template_str)
+        state_dict = self._state_to_dict(state)
+
+        # First pass: Jinja2 rendering
+        rendered = template.render(
+            state=state_dict,
+            prompt_templates=self.config.prompt_templates,
+            context=self.context,
+        )
+
+        # Second pass: Python format string for backward compatibility
+        # Build format args from state, with special handling for messages
+        format_args: dict[str, Any] = {}
+
+        # Format messages as a readable string
+        messages = state_dict.get("messages", [])
+        if messages:
+            if isinstance(messages[0], BaseMessage):
+                format_args["messages"] = self._format_messages_for_prompt(messages)
+            else:
+                format_args["messages"] = str(messages)
+        else:
+            format_args["messages"] = ""
+
+        # Add current date
+        format_args["date"] = datetime.datetime.now().strftime("%Y-%m-%d")
+
+        # Add other state fields
+        for key, value in state_dict.items():
+            if key not in format_args and not isinstance(value, (list, dict)):
+                format_args[key] = str(value) if value is not None else ""
+
+        # Apply format string substitution
+        # Use a safer approach that handles curly braces in JSON content
+        import re
+
+        def replace_placeholder(match: re.Match[str]) -> str:
+            key = match.group(1)
+            return str(format_args.get(key, match.group(0)))
+
+        # Match {word} but skip doubled braces ({{word}}, {word}}), which
+        # denote literal braces and must survive substitution
+        rendered = re.sub(r"(?<!\{)\{(\w+)\}(?!\})", replace_placeholder, rendered)
+
+        logger.debug(f"Rendered template ({len(rendered)} chars), messages count: {len(messages)}")
+
+        return rendered
+
+    def _build_node(self, config: GraphNodeConfig) -> NodeFunction:
+        """Build a node function from configuration."""
+        match config.type:
+            case NodeType.LLM:
+                return self._build_llm_node(config)
+            case NodeType.TOOL:
+                return self._build_tool_node(config)
+            case NodeType.ROUTER:
+                return self._build_router_node(config)
+            case NodeType.SUBAGENT:
+                return self._build_subagent_node(config)
+            case NodeType.PARALLEL:
+                return self._build_parallel_node(config)
+            case NodeType.TRANSFORM:
+                return self._build_transform_node(config)
+            case NodeType.HUMAN:
+                return self._build_human_node(config)
+            case _:
+                raise ValueError(f"Unknown node type: {config.type}")
+
+    def _build_llm_node(self, config: GraphNodeConfig) -> NodeFunction:
+        """Build an LLM reasoning node.
+ + Supports structured output via `structured_output` config: + - Dynamically creates Pydantic model from JSON schema + - Uses LangChain's with_structured_output() for parsing + - Extracts `message_key` field for user-facing message + """ + llm_config = config.llm_config + if not llm_config: + raise ValueError(f"LLM node '{config.id}' missing llm_config") + + # Pre-build structured output model if configured + structured_model: type[BaseModel] | None = None + if llm_config.structured_output: + structured_model = self._build_structured_output_model(config.id, llm_config.structured_output) + + async def llm_node(state: StateDict) -> StateDict: + logger.info(f"[LLM Node: {config.id}] Starting execution") + + # Get state as dict for inspection + state_dict = self._state_to_dict(state) + messages_in_state = state_dict.get("messages", []) + logger.info(f"[LLM Node: {config.id}] Input messages count: {len(messages_in_state)}") + + # Render prompt template + prompt = self._render_template(llm_config.prompt_template, state) + logger.info(f"[LLM Node: {config.id}] Rendered prompt length: {len(prompt)}") + + # Get LLM with optional overrides + llm = await self.llm_factory( + model=llm_config.model_override, + temperature=llm_config.temperature_override, + ) + + # If structured output is configured, use with_structured_output + if structured_model: + llm = llm.with_structured_output(structured_model) + + # Bind tools if enabled (only for non-structured output) + elif llm_config.tools_enabled and self.tool_registry: + tools_to_bind = list(self.tool_registry.values()) + if llm_config.tool_filter: + tools_to_bind = [t for t in tools_to_bind if t.name in llm_config.tool_filter] + if tools_to_bind: + llm = llm.bind_tools(tools_to_bind) + + # Build messages - handle both dict and Pydantic state + state_dict = self._state_to_dict(state) + messages: list[BaseMessage] = state_dict.get("messages", []) + messages = messages + [HumanMessage(content=prompt)] + + # Invoke LLM + response = await llm.ainvoke(messages) + + # Handle structured output + if structured_model and isinstance(response, BaseModel): + # Convert structured response to dict for state update + response_dict = response.model_dump() + + # Determine what to show in messages + user_message = "" + + # 1. Check conditional message selection first + if llm_config.message_key_condition: + cond = llm_config.message_key_condition + condition_field = cond.get("condition_field") + true_key = cond.get("true_key") + false_key = cond.get("false_key") + + if condition_field and true_key and false_key: + condition_value = response_dict.get(condition_field, False) + selected_key = true_key if condition_value else false_key + user_message = str(response_dict.get(selected_key, "")) + logger.debug( + f"LLM node {config.id}: {condition_field}={condition_value}, " + f"using message from '{selected_key}'" + ) + + # 2. Fall back to simple message_key + elif llm_config.message_key and llm_config.message_key in response_dict: + user_message = str(response_dict[llm_config.message_key]) + + # 3. 
Fall back to first non-empty string field + else: + user_message = next((str(v) for v in response_dict.values() if isinstance(v, str) and v), "") + + logger.info(f"[LLM Node: {config.id}] Structured output completed") + + # Build node metadata for frontend display and persistence + node_metadata = { + "node_id": config.id, + "node_name": config.name or config.id, + "node_type": "llm", + "is_intermediate": config.id not in ("final_report_generation", "agent", "model"), + "structured_output": response_dict, # Include full structured data + } + + # Build agent state for persistence (includes current node output + context) + agent_state = { + "current_node": config.id, + "node_outputs": {config.id: response_dict}, + } + + # Return all fields from structured output + message for chat + result: StateDict = { + llm_config.output_key: response_dict, # Full structured data + "messages": [ + AIMessage( + content=user_message, + additional_kwargs={ + "node_metadata": node_metadata, + "agent_state": agent_state, + }, + ) + ], + } + # Also set individual fields in state for routing conditions + result.update(response_dict) + return result + + # Handle regular text response + content_str = _extract_content_str(getattr(response, "content", response)) + + logger.info(f"[LLM Node: {config.id}] Text output completed") + + # Build node metadata for regular text output + node_metadata = { + "node_id": config.id, + "node_name": config.name or config.id, + "node_type": "llm", + "is_intermediate": config.id not in ("final_report_generation", "agent", "model"), + } + + # For intermediate nodes, truncate output to save space in metadata + # For final nodes, keep full content for persistence to message.content + is_intermediate = node_metadata["is_intermediate"] + output_for_state = content_str[:500] if is_intermediate and len(content_str) > 500 else content_str + + agent_state = { + "current_node": config.id, + "node_outputs": {config.id: output_for_state}, + } + + return { + llm_config.output_key: content_str, + "messages": [ + AIMessage( + content=content_str, + additional_kwargs={ + "node_metadata": node_metadata, + "agent_state": agent_state, + }, + ) + ], + } + + return llm_node + + def _build_structured_output_model(self, node_id: str, schema: StructuredOutputSchema) -> type[BaseModel]: + """Build a Pydantic model from JSON-defined structured output schema.""" + + # Type mapping from schema types to Python types + type_map: dict[str, Any] = { + "string": str, + "str": str, + "bool": bool, + "int": int, + "float": float, + "list": list[Any], + "dict": dict[str, Any], + } + + fields: dict[str, tuple[Any, Any]] = {} + for field_name, field_def in schema.fields.items(): + python_type = type_map.get(field_def.type, Any) + if field_def.required: + fields[field_name] = (python_type, Field(description=field_def.description)) + else: + fields[field_name] = ( + python_type | None, + Field(default=field_def.default, description=field_def.description), + ) + + # Create dynamic model + model_name = f"{node_id.title().replace('_', '')}Output" + return create_model( + model_name, + __doc__=schema.description, + **fields, # type: ignore[arg-type] + ) + + def _build_tool_node(self, config: GraphNodeConfig) -> NodeFunction: + """Build a tool execution node.""" + tool_config = config.tool_config + if not tool_config: + raise ValueError(f"Tool node '{config.id}' missing tool_config") + + async def tool_node(state: StateDict) -> StateDict: + logger.debug(f"Executing Tool node: {config.id} (tool: {tool_config.tool_name})") + + # 
Get tool + tool = self.tool_registry.get(tool_config.tool_name) + if not tool: + raise ValueError(f"Tool not found: {tool_config.tool_name}") + + # Render arguments + args: dict[str, str] = {} + for key, template_str in tool_config.arguments_template.items(): + args[key] = self._render_template(template_str, state) + + # Execute tool with timeout + try: + result = await asyncio.wait_for( + tool.ainvoke(args), + timeout=tool_config.timeout_seconds, + ) + except asyncio.TimeoutError: + logger.error(f"Tool {tool_config.tool_name} timed out") + result = {"error": f"Tool execution timed out after {tool_config.timeout_seconds}s"} + + logger.debug(f"Tool node {config.id} completed") + return {tool_config.output_key: result} + + return tool_node + + def _build_router_node(self, config: GraphNodeConfig) -> NodeFunction: + """Build a routing/branching node.""" + router_config = config.router_config + if not router_config: + raise ValueError(f"Router node '{config.id}' missing router_config") + + async def router_node(state: StateDict) -> StateDict: + logger.debug(f"Executing Router node: {config.id}") + + if router_config.strategy == "condition": + # Evaluate conditions in order + for condition in router_config.conditions: + if self._evaluate_condition(condition, state): + logger.debug(f"Router {config.id} matched condition, routing to {condition.target}") + return {"_next_node": condition.target} + + logger.debug(f"Router {config.id} using default route: {router_config.default_route}") + return {"_next_node": router_config.default_route} + + elif router_config.strategy == "llm": + # Use LLM to decide route + if not router_config.llm_prompt: + raise ValueError("LLM routing strategy requires llm_prompt") + + prompt = self._render_template(router_config.llm_prompt, state) + llm = await self.llm_factory() + response = await llm.ainvoke(prompt) + route = _extract_content_str(response.content).strip() + + if route in router_config.routes: + return {"_next_node": route} + return {"_next_node": router_config.default_route} + + else: + return {"_next_node": router_config.default_route} + + return router_node + + def _build_subagent_node(self, config: GraphNodeConfig) -> NodeFunction: + """Build a subagent invocation node.""" + subagent_config = config.subagent_config + if not subagent_config: + raise ValueError(f"Subagent node '{config.id}' missing subagent_config") + + async def subagent_node(state: StateDict) -> StateDict: + logger.debug(f"Executing Subagent node: {config.id} (agent: {subagent_config.agent_ref})") + + # Import here to avoid circular dependency + from app.agents.system import system_agent_registry + + # Try to get system agent + system_agent = system_agent_registry.get_instance(subagent_config.agent_ref) + + if system_agent: + # Build subagent graph + subagent_graph = system_agent.build_graph() + + # Map input state + subagent_input: StateDict = {} + for child_key, parent_template in subagent_config.input_mapping.items(): + subagent_input[child_key] = self._render_template(parent_template, state) + + # Execute subagent + try: + result = await asyncio.wait_for( + subagent_graph.ainvoke(subagent_input), + timeout=subagent_config.timeout_seconds, + ) + except asyncio.TimeoutError: + logger.error(f"Subagent {subagent_config.agent_ref} timed out") + result = {"error": "Subagent execution timed out"} + + # Map output state + output: StateDict = {} + for parent_key, child_key in subagent_config.output_mapping.items(): + if child_key in result: + output[parent_key] = result[child_key] + + return 
output + else: + # TODO: Handle user-defined agent by UUID + logger.warning(f"Subagent not found: {subagent_config.agent_ref}") + return {"error": f"Subagent not found: {subagent_config.agent_ref}"} + + return subagent_node + + def _build_parallel_node(self, config: GraphNodeConfig) -> NodeFunction: + """Build a parallel execution node.""" + parallel_config = config.parallel_config + if not parallel_config: + raise ValueError(f"Parallel node '{config.id}' missing parallel_config") + + async def parallel_node(state: StateDict) -> StateDict: + logger.debug(f"Executing Parallel node: {config.id}") + + # This is a placeholder - actual parallel execution requires + # more complex handling with LangGraph's parallel capabilities + logger.warning("Parallel node execution not fully implemented") + return {parallel_config.merge_key: []} + + return parallel_node + + def _build_transform_node(self, config: GraphNodeConfig) -> NodeFunction: + """Build a data transformation node.""" + transform_config = config.transform_config + if not transform_config: + raise ValueError(f"Transform node '{config.id}' missing transform_config") + + async def transform_node(state: StateDict) -> StateDict: + logger.debug(f"Executing Transform node: {config.id}") + + if transform_config.template: + # Use Jinja2 template + result: Any = self._render_template(transform_config.template, state) + elif transform_config.expression: + # Evaluate Python expression (restricted context) + # WARNING: This should be sandboxed in production + try: + result = eval(transform_config.expression, {"state": state, "__builtins__": {}}) + except Exception as e: + logger.error(f"Transform expression error: {e}") + result = None + else: + result = None + + return {transform_config.output_key: result} + + return transform_node + + def _build_human_node(self, config: GraphNodeConfig) -> NodeFunction: + """Build a human-in-the-loop node.""" + human_config = config.human_config + if not human_config: + raise ValueError(f"Human node '{config.id}' missing human_config") + + async def human_node(state: StateDict) -> StateDict: + logger.debug(f"Executing Human node: {config.id}") + # Human nodes typically pause execution and wait for external input + # This is handled at a higher level in the execution framework + return {"_human_checkpoint": True, "_human_prompt": human_config.prompt_template} + + return human_node + + def _add_edges(self, graph: DynamicStateGraph) -> None: + """Add all edges to the graph.""" + # Group edges by source node for conditional edges + edges_by_source: dict[str, list[GraphEdgeConfig]] = {} + for edge in self.config.edges: + if edge.from_node not in edges_by_source: + edges_by_source[edge.from_node] = [] + edges_by_source[edge.from_node].append(edge) + + # Process edges for each source + for from_node, edges in edges_by_source.items(): + # Check if any edges have conditions + conditional_edges = [e for e in edges if e.condition] + unconditional_edges = [e for e in edges if not e.condition] + + if from_node == "START": + # Entry point + if unconditional_edges: + graph.add_edge(START, unconditional_edges[0].to_node) + elif conditional_edges: + # Conditional entry (unusual but supported) + graph.add_conditional_edges( + START, + self._build_conditional_router(conditional_edges), + {e.to_node: e.to_node for e in conditional_edges}, + ) + elif conditional_edges: + # Add conditional edges + # Build condition map, handling END specially + condition_map: dict[Hashable, str] = {} + for e in conditional_edges: + target = 
e.condition.target # type: ignore[union-attr] + to_node = e.to_node + # Map to langgraph END constant if target is "END" + if to_node == "END": + condition_map[target] = END # type: ignore[assignment] + else: + condition_map[target] = to_node + + if unconditional_edges: + condition_map["default"] = unconditional_edges[0].to_node + else: + condition_map["default"] = END # type: ignore[assignment] + + graph.add_conditional_edges( + from_node, + self._build_conditional_router(conditional_edges), + condition_map, + ) + elif unconditional_edges: + # Simple unconditional edge + to_node = unconditional_edges[0].to_node + if to_node == "END": + graph.add_edge(from_node, END) + else: + graph.add_edge(from_node, to_node) + + def _build_conditional_router(self, edges: list[GraphEdgeConfig]) -> RouterFunction: + """Build a routing function for conditional edges.""" + + def router(state: StateDict) -> str: + # Sort by priority (higher first) + sorted_edges = sorted(edges, key=lambda e: e.priority, reverse=True) + + for edge in sorted_edges: + if edge.condition and self._evaluate_condition(edge.condition, state): + return edge.condition.target + + return "default" + + return router + + def _evaluate_condition(self, condition: EdgeCondition, state: StateDict) -> bool: + """Evaluate a condition against the current state.""" + state_dict = self._state_to_dict(state) + value = state_dict.get(condition.state_key) + + match condition.operator: + case "eq": + return value == condition.value + case "neq": + return value != condition.value + case "contains": + return condition.value in value if value else False + case "not_contains": + return condition.value not in value if value else True + case "gt": + return value > condition.value if value is not None else False + case "gte": + return value >= condition.value if value is not None else False + case "lt": + return value < condition.value if value is not None else False + case "lte": + return value <= condition.value if value is not None else False + case "in": + return value in condition.value if condition.value else False + case "not_in": + return value not in condition.value if condition.value else True + case "truthy": + return bool(value) + case "falsy": + return not bool(value) + case "matches": + import re + + return bool(re.match(str(condition.value), str(value))) if value else False + case _: + return False + + +# Export +__all__ = ["GraphBuilder", "build_state_class"] diff --git a/service/app/agents/react_agent.py b/service/app/agents/react_agent.py deleted file mode 100644 index 59506d6b..00000000 --- a/service/app/agents/react_agent.py +++ /dev/null @@ -1,139 +0,0 @@ -""" -ReAct Agent - Default tool-calling agent for chat conversations. - -This module provides the standard ReAct (Reasoning + Acting) agent that uses -LangChain's create_agent for tool-calling conversations. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any - -from langchain.agents import create_agent -from langchain_core.language_models import BaseChatModel -from langchain_core.tools import BaseTool - -from .base_graph_agent import BaseBuiltinGraphAgent - -if TYPE_CHECKING: - from langgraph.graph.state import CompiledStateGraph - -logger = logging.getLogger(__name__) - - -class ReActAgent(BaseBuiltinGraphAgent): - """ - Default ReAct agent for tool-calling conversations. - - Uses LangGraph's prebuilt create_react_agent which implements - the ReAct pattern: Reasoning + Acting with tool calls. 
- - This agent: - - Processes user messages - - Decides whether to use tools or respond directly - - Executes tools and incorporates results - - Generates final responses - - Supports combining provider-side tools (like Google Search) with - client-side tools (like MCP tools) by binding them together. - - Attributes: - llm: LangChain chat model for reasoning - tools: List of tools the agent can use - system_prompt: System prompt to guide agent behavior - google_search_enabled: Whether to enable Google's builtin search - """ - - llm: BaseChatModel - tools: list[BaseTool] - system_prompt: str - google_search_enabled: bool - - def __init__( - self, - llm: BaseChatModel, - tools: list[BaseTool], - system_prompt: str = "", - google_search_enabled: bool = False, - name: str = "ReAct Agent", - description: str = "Default tool-calling agent using ReAct pattern", - version: str = "1.0.0", - ) -> None: - """ - Initialize the ReAct agent. - - Args: - llm: LangChain chat model for reasoning - tools: List of tools the agent can use - system_prompt: System prompt to guide agent behavior - google_search_enabled: Enable Google's builtin web search - name: Human-readable name - description: Description of agent capabilities - version: Version string - """ - super().__init__( - name=name, - description=description, - version=version, - capabilities=["tool-calling", "reasoning", "multi-turn-conversation"], - tags=["default", "react", "chat"], - ) - self.llm = llm - self.tools = tools - self.system_prompt = system_prompt - self.google_search_enabled = google_search_enabled - - def build_graph(self) -> "CompiledStateGraph": - """ - Build the ReAct agent graph using LangGraph's prebuilt implementation. - - When google_search_enabled is True, binds both the google_search - provider tool and MCP tools together to the model. - - Returns: - Compiled StateGraph ready for execution - """ - logger.info(f"Building ReAct agent with {len(self.tools)} tools, google_search={self.google_search_enabled}") - - # Combine all tools for binding - # MCP tools (client-side) are passed as-is - # Provider-side web search is bound at model creation time. - all_tools: list[BaseTool] = list(self.tools) - - # Use LangChain's create_agent (replacement for deprecated create_react_agent) - # Pass all tools together so they're bound in a single call - agent: CompiledStateGraph = create_agent( - model=self.llm, - tools=all_tools, - system_prompt=self.system_prompt if self.system_prompt else None, - ) - - logger.debug("Agent graph built successfully") - return agent - - def get_state_schema(self) -> dict[str, Any]: - """ - Return the state schema for ReAct agent. - - The prebuilt create_react_agent uses a standard messages-based schema. - - Returns: - State schema definition - """ - return { - "messages": "list[BaseMessage] - Conversation messages", - } - - def supports_streaming(self) -> bool: - """ReAct agent supports streaming.""" - return True - - def get_required_tools(self) -> list[str]: - """Return names of tools configured for this agent.""" - tool_names = [tool.name for tool in self.tools] - return tool_names - - -# Note: This agent is NOT auto-discovered because it requires runtime parameters -# (llm, tools). It's instantiated via the factory instead. 
diff --git a/service/app/agents/system/__init__.py b/service/app/agents/system/__init__.py new file mode 100644 index 00000000..bad5e7d8 --- /dev/null +++ b/service/app/agents/system/__init__.py @@ -0,0 +1,219 @@ +""" +System Agent Registry - Central registry for Python-coded system agents. + +This module provides automatic discovery and registration of system agents. +System agents are pre-built, Python-implemented agents that provide +complex workflows like deep research, code analysis, etc. +""" + +from __future__ import annotations + +import importlib +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from .base import BaseSystemAgent + +if TYPE_CHECKING: + from langchain_core.language_models import BaseChatModel + from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +class SystemAgentRegistry: + """ + Registry for system agents. + + This registry: + - Automatically discovers system agents in subdirectories + - Provides access to system agent classes and instances + - Manages component registration from system agents + """ + + def __init__(self) -> None: + self._agent_classes: dict[str, type[BaseSystemAgent]] = {} + self._instances: dict[str, BaseSystemAgent] = {} + self._initialized: bool = False + + def register_class(self, key: str, agent_class: type[BaseSystemAgent], override: bool = False) -> None: + """ + Register a system agent class. + + Args: + key: Unique system key (e.g., "deep_research") + agent_class: The agent class to register + override: If True, allow overwriting existing registration + """ + if key in self._agent_classes and not override: + raise ValueError(f"System agent '{key}' already registered") + + self._agent_classes[key] = agent_class + logger.info(f"Registered system agent class: {key}") + + def get_class(self, key: str) -> type[BaseSystemAgent] | None: + """Get a system agent class by key.""" + return self._agent_classes.get(key) + + def get_instance( + self, + key: str, + llm: "BaseChatModel | None" = None, + tools: list["BaseTool"] | None = None, + ) -> BaseSystemAgent | None: + """ + Get or create a system agent instance. + + If the instance doesn't exist, it will be created and configured + with the provided LLM and tools. + + Args: + key: System agent key + llm: LLM to configure the agent with + tools: Tools to make available to the agent + + Returns: + Configured agent instance or None if not found + """ + # Check if we have a cached, configured instance + cache_key = f"{key}:{id(llm)}:{id(tools)}" + if cache_key in self._instances: + return self._instances[cache_key] + + # Get agent class + agent_class = self._agent_classes.get(key) + if not agent_class: + logger.warning(f"System agent not found: {key}") + return None + + # Create and configure instance + try: + instance = agent_class() + instance.configure(llm=llm, tools=tools) + self._instances[cache_key] = instance + logger.debug(f"Created system agent instance: {key}") + return instance + except Exception as e: + logger.error(f"Failed to create system agent '{key}': {e}") + return None + + def list_keys(self) -> list[str]: + """List all registered system agent keys.""" + return list(self._agent_classes.keys()) + + def get_all_metadata(self) -> list[dict[str, Any]]: + """ + Get metadata for all registered system agents. 
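+
+        Agents whose metadata cannot be loaded are still listed, with
+        forkable=False and the error message attached.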
+ + Returns: + List of metadata dictionaries + """ + result = [] + for key, agent_class in self._agent_classes.items(): + try: + # Create temporary instance for metadata + instance = agent_class() + result.append( + { + "key": key, + "metadata": instance.get_metadata(), + "forkable": True, + "components": [c.metadata.model_dump() for c in instance.get_exported_components()], + } + ) + except Exception as e: + logger.warning(f"Failed to get metadata for '{key}': {e}") + result.append( + { + "key": key, + "metadata": {"name": key, "description": "Metadata unavailable"}, + "forkable": False, + "error": str(e), + } + ) + return result + + def initialize_components(self) -> None: + """ + Initialize and register components from all system agents. + + This should be called during application startup to ensure + all components are available in the component registry. + """ + if self._initialized: + return + + logger.info("Initializing system agent components...") + + for key, agent_class in self._agent_classes.items(): + try: + instance = agent_class() + instance.register_components() + logger.debug(f"Registered components for: {key}") + except Exception as e: + logger.error(f"Failed to register components for '{key}': {e}") + + self._initialized = True + logger.info(f"System agent components initialized ({len(self._agent_classes)} agents)") + + def discover_agents(self) -> None: + """ + Discover and register system agents from subdirectories. + + Scans the system directory for Python packages containing + agent implementations. + """ + current_dir = Path(__file__).parent + + # Look for subdirectories with __init__.py + for path in current_dir.iterdir(): + if path.is_dir() and not path.name.startswith("_"): + init_file = path / "__init__.py" + if init_file.exists(): + self._try_import_agent(path.name) + + def _try_import_agent(self, module_name: str) -> None: + """Attempt to import and register an agent from a module.""" + try: + # Import the module + module = importlib.import_module(f".{module_name}", package=__package__) + + # Look for BaseSystemAgent subclasses + for attr_name in dir(module): + attr = getattr(module, attr_name) + + # Check if it's a class that inherits from BaseSystemAgent + if ( + isinstance(attr, type) + and issubclass(attr, BaseSystemAgent) + and attr is not BaseSystemAgent + and hasattr(attr, "SYSTEM_KEY") + and attr.SYSTEM_KEY + ): + self.register_class(attr.SYSTEM_KEY, attr, override=True) + logger.info(f"Discovered system agent: {attr.SYSTEM_KEY} from {module_name}") + + except Exception as e: + logger.warning(f"Failed to import system agent from {module_name}: {e}") + + +# Global registry instance +system_agent_registry = SystemAgentRegistry() + + +def _initialize_registry() -> None: + """Initialize the registry on module import.""" + system_agent_registry.discover_agents() + + +# Run discovery on import +_initialize_registry() + + +# Export +__all__ = [ + "BaseSystemAgent", + "SystemAgentRegistry", + "system_agent_registry", +] diff --git a/service/app/agents/system/base.py b/service/app/agents/system/base.py new file mode 100644 index 00000000..eca02e0a --- /dev/null +++ b/service/app/agents/system/base.py @@ -0,0 +1,209 @@ +""" +Base System Agent - Abstract base class for Python-coded system agents. 
+ +System agents are complex, Python-implemented agents that: +- Provide sophisticated workflows (e.g., DeepResearch, CodeAnalysis) +- Export their graph config as JSON for user forking +- Register reusable components to the component registry +- Are available to all users as built-in capabilities +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar + +from app.agents.base_graph_agent import BaseBuiltinGraphAgent +from app.agents.components import BaseComponent + +if TYPE_CHECKING: + from langchain_core.language_models import BaseChatModel + from langchain_core.tools import BaseTool + + from app.schemas.graph_config import GraphConfig + + +class BaseSystemAgent(BaseBuiltinGraphAgent): + """ + Abstract base class for Python-coded system agents. + + System agents are sophisticated, pre-built agents that provide complex + workflows beyond what simple JSON configuration can achieve. They: + + 1. Are implemented in Python for maximum flexibility + 2. Export their workflow as JSON for users to fork and customize + 3. Register reusable components (nodes, prompts, schemas) to the registry + 4. Are available as built-in capabilities to all users + + Subclasses must: + - Define SYSTEM_KEY class variable (e.g., "deep_research") + - Implement build_graph() to construct the workflow + - Implement export_graph_config() to export JSON representation + - Implement get_exported_components() to list reusable components + + Example: + class DeepResearchAgent(BaseSystemAgent): + SYSTEM_KEY = "deep_research" + + def build_graph(self): + # Build LangGraph workflow + ... + + def export_graph_config(self): + # Export as GraphConfig + ... + """ + + # Class-level system agent key (must be unique) + SYSTEM_KEY: ClassVar[str] = "" + + # Runtime dependencies (injected when instantiated) + llm: "BaseChatModel | None" = None + tools: list["BaseTool"] | None = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Auto-register system agents when subclass is defined.""" + super().__init_subclass__(**kwargs) + + # Only register if SYSTEM_KEY is defined and non-empty + if hasattr(cls, "SYSTEM_KEY") and cls.SYSTEM_KEY: + # Defer registration to avoid circular imports + # Registration happens in __init__.py + pass + + def __init__( + self, + name: str = "", + description: str = "", + version: str = "1.0.0", + capabilities: list[str] | None = None, + tags: list[str] | None = None, + author: str | None = "Xyzen", + license_: str | None = "MIT", + ) -> None: + """ + Initialize the system agent. + + Args: + name: Human-readable name + description: Description of agent capabilities + version: Semantic version string + capabilities: List of capability strings + tags: Tags for categorization + author: Agent author + license_: License identifier + """ + super().__init__( + name=name, + description=description, + version=version, + capabilities=capabilities, + tags=tags, + author=author, + license_=license_, + ) + + def configure( + self, + llm: "BaseChatModel | None" = None, + tools: list["BaseTool"] | None = None, + ) -> "BaseSystemAgent": + """ + Configure runtime dependencies. + + This method allows injecting LLM and tools at runtime, + enabling the same agent class to work with different providers. 
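+
+        Example (minimal sketch; ``my_llm`` and ``my_tool`` stand in for
+        caller-supplied objects):
+
+            agent = DeepResearchAgent().configure(llm=my_llm, tools=[my_tool])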
+ + Args: + llm: LangChain LLM to use + tools: List of tools available to the agent + + Returns: + Self for method chaining + """ + self.llm = llm + self.tools = tools or [] + return self + + @abstractmethod + def export_graph_config(self) -> "GraphConfig": + """ + Export the agent's workflow as a JSON GraphConfig. + + This allows users to: + 1. View the agent's workflow structure + 2. Fork and customize the agent + 3. Learn from the implementation + + Returns: + GraphConfig representing this agent's workflow + """ + ... + + @abstractmethod + def get_exported_components(self) -> list[BaseComponent]: + """ + Return list of reusable components this agent provides. + + Components are automatically registered to the global registry + and can be used in user-defined agents. + + Returns: + List of components to export + """ + ... + + def get_forkable_config(self) -> dict[str, Any]: + """ + Get configuration for creating a user-customized fork. + + Returns a complete Agent creation payload that can be customized + by users and saved as their own agent. + + Returns: + Dictionary suitable for AgentCreate + """ + graph_config = self.export_graph_config() + graph_config_dict = graph_config.model_dump() + + # Add system_agent_key to metadata for fallback behavior + if "metadata" not in graph_config_dict: + graph_config_dict["metadata"] = {} + graph_config_dict["metadata"]["system_agent_key"] = self.SYSTEM_KEY + + return { + "name": f"{self.name} (Custom)", + "description": self.description, + "graph_config": graph_config_dict, + "tags": list(self.tags) + ["forked", f"from:{self.SYSTEM_KEY}"], + "model": None, # Inherit from user settings + "prompt": None, # Uses graph_config prompts + } + + def register_components(self) -> None: + """ + Register this agent's components to the global registry. + + Called automatically during system initialization. 
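+
+        Registration failures are logged as warnings and skipped, so one
+        misbehaving component does not abort startup.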
+ """ + from app.agents.components import component_registry + + for component in self.get_exported_components(): + try: + component_registry.register(component, override=True) + except Exception as e: + import logging + + logging.getLogger(__name__).warning(f"Failed to register component {component.metadata.key}: {e}") + + def get_system_key(self) -> str: + """Get the system key for this agent.""" + return self.SYSTEM_KEY + + def is_configured(self) -> bool: + """Check if runtime dependencies are configured.""" + return self.llm is not None + + +# Export +__all__ = ["BaseSystemAgent"] diff --git a/service/app/agents/system/deep_research/__init__.py b/service/app/agents/system/deep_research/__init__.py new file mode 100644 index 00000000..9456dce4 --- /dev/null +++ b/service/app/agents/system/deep_research/__init__.py @@ -0,0 +1,54 @@ +""" +Deep Research System Agent + +A research system that provides: +- User clarification (optional) +- Research brief generation +- Tool-enabled research via LLM +- Comprehensive report synthesis with citations + +Usage: + from app.agents.system.deep_research import DeepResearchAgent + + agent = DeepResearchAgent() + agent.configure(llm=my_llm, tools=[search_tool]) + graph = agent.build_graph() + + result = await graph.ainvoke({"messages": [HumanMessage(content="Research topic")]}) + + # Customize configuration + from app.agents.system.deep_research import DeepResearchConfig + config = DeepResearchConfig(allow_clarification=False, max_researcher_iterations=4) + agent = DeepResearchAgent(config=config) +""" + +from app.agents.system.deep_research.agent import DeepResearchAgent +from app.agents.system.deep_research.components import ( + QueryAnalyzerComponent, + ResearchSupervisorComponent, + SynthesisComponent, +) +from app.agents.system.deep_research.configuration import DEFAULT_CONFIG, DeepResearchConfig +from app.agents.system.deep_research.graph_config import ( + DEFAULT_GRAPH_CONFIG, + create_deep_research_graph_config, + create_state_schema, + get_default_prompts, +) + +__all__ = [ + # Main agent + "DeepResearchAgent", + # Configuration + "DeepResearchConfig", + "DEFAULT_CONFIG", + # Graph configuration (JSON-configurable) + "create_deep_research_graph_config", + "create_state_schema", + "get_default_prompts", + "DEFAULT_GRAPH_CONFIG", + # Components + "QueryAnalyzerComponent", + "SynthesisComponent", + "ResearchSupervisorComponent", +] diff --git a/service/app/agents/system/deep_research/agent.py b/service/app/agents/system/deep_research/agent.py new file mode 100644 index 00000000..afa9123a --- /dev/null +++ b/service/app/agents/system/deep_research/agent.py @@ -0,0 +1,200 @@ +""" +Deep Research System Agent + +A sophisticated multi-agent research system that: +1. Clarifies research scope with the user (optional) +2. Generates a structured research brief +3. Conducts research via LLM with tools +4. Synthesizes findings into a comprehensive report with citations + +This agent uses GraphBuilder for JSON-configurable workflows. 
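+
+See graph_config.py for the canonical JSON graph definition and prompts.py for
+the default prompt templates that drive each node.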
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from app.agents.components import BaseComponent +from app.agents.system.base import BaseSystemAgent +from app.agents.system.deep_research.configuration import DEFAULT_CONFIG, DeepResearchConfig +from app.agents.system.deep_research.graph_config import ( + create_deep_research_graph_config, + create_state_schema, +) +from app.agents.system.deep_research.utils import get_research_tools +from app.agents.types import DynamicCompiledGraph +from app.schemas.graph_config import GraphConfig + +if TYPE_CHECKING: + from langchain_core.language_models import BaseChatModel + from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +class DeepResearchAgent(BaseSystemAgent): + """ + Deep Research Agent - Research workflow with GraphBuilder. + + This agent implements a research workflow: + + Main Flow: + 1. clarify_with_user - Ask clarifying questions if needed + 2. write_research_brief - Generate structured research brief + 3. research_supervisor - Coordinate research with tools + 4. final_report_generation - Synthesize comprehensive report + + Features: + - JSON-configurable workflow via GraphConfig + - Customizable prompts via prompt_templates + - Tool-enabled research via LLM + - Citation-rich final reports + + Note: + Advanced features like parallel research and supervisor loops + require future GraphBuilder enhancements. See metadata.pending_features + in the exported GraphConfig for details. + """ + + SYSTEM_KEY = "deep_research" + + def __init__(self, config: DeepResearchConfig | None = None) -> None: + super().__init__( + name="Deep Research", + description=( + "Research system that generates comprehensive reports with citations. " + "Uses structured workflow: clarification → research brief → " + "research execution → final report synthesis." + ), + version="3.0.0", + capabilities=[ + "deep-research", + "citations", + "structured-workflow", + "tool-enabled", + "json-configurable", + ], + tags=["research", "multi-phase", "citations", "configurable"], + author="Xyzen", + ) + self.config = config or DEFAULT_CONFIG + self._main_graph: DynamicCompiledGraph | None = None + self._graph_config: GraphConfig | None = None + + def configure( + self, + llm: "BaseChatModel | None" = None, + tools: list["BaseTool"] | None = None, + ) -> "DeepResearchAgent": + """ + Configure runtime dependencies and build the graph. + + Args: + llm: LangChain LLM to use for all operations + tools: List of tools available to researchers (session tools) + + Returns: + Self for method chaining + """ + super().configure(llm, tools) + + if llm is not None: + self._build_graph() + + return self + + def _build_graph(self) -> None: + """Build graph using GraphBuilder from JSON config.""" + from app.agents.graph_builder import GraphBuilder + + if not self.llm: + raise RuntimeError("LLM not configured. 
Call configure() first.") + + # Create the canonical graph config + self._graph_config = create_deep_research_graph_config( + max_concurrent_research_units=self.config.max_concurrent_research_units, + max_researcher_iterations=self.config.max_researcher_iterations, + allow_clarification=self.config.allow_clarification, + ) + + # Prepare tools + research_tools = get_research_tools(self.tools or []) + tool_registry = {t.name: t for t in research_tools} + + # Create LLM factory that returns the configured LLM + llm = self.llm # Capture for closure + + async def llm_factory(**kwargs: Any) -> "BaseChatModel": + # In the future, this could support model overrides from config + return llm + + # Build graph using GraphBuilder + builder = GraphBuilder( + config=self._graph_config, + llm_factory=llm_factory, + tool_registry=tool_registry, + ) + + self._main_graph = builder.build() + + logger.info(f"Built Deep Research agent using GraphBuilder (nodes={len(self._graph_config.nodes)})") + + def build_graph(self) -> DynamicCompiledGraph: + """Build and return the Deep Research graph. + + Returns: + Compiled StateGraph ready for execution + + Raises: + RuntimeError: If configure() hasn't been called + """ + if self._main_graph is None: + if not self.llm: + raise RuntimeError("LLM not configured. Call configure() first.") + self._build_graph() + + return self._main_graph # type: ignore[return-value] + + def get_state_schema(self) -> dict[str, Any]: + """Return the state schema for this agent.""" + schema = create_state_schema() + return {name: field.type for name, field in schema.fields.items()} + + def export_graph_config(self) -> GraphConfig: + """Export as JSON GraphConfig for forking. + + This returns the canonical GraphConfig that can be used to: + 1. Reconstruct the agent workflow via GraphBuilder + 2. Fork and customize the agent configuration + 3. Visualize the workflow structure + + Note: The exported config includes TODO markers for features that + require GraphBuilder enhancement (parallel execution, loops). + """ + if self._graph_config: + return self._graph_config + + return create_deep_research_graph_config( + max_concurrent_research_units=self.config.max_concurrent_research_units, + max_researcher_iterations=self.config.max_researcher_iterations, + allow_clarification=self.config.allow_clarification, + ) + + def get_exported_components(self) -> list[BaseComponent]: + """Return reusable components from this agent.""" + from app.agents.system.deep_research.components import ( + QueryAnalyzerComponent, + ResearchSupervisorComponent, + SynthesisComponent, + ) + + return [ + QueryAnalyzerComponent(), + SynthesisComponent(), + ResearchSupervisorComponent(), + ] + + +# Export +__all__ = ["DeepResearchAgent", "DeepResearchConfig"] diff --git a/service/app/agents/system/deep_research/components.py b/service/app/agents/system/deep_research/components.py new file mode 100644 index 00000000..ceea3e01 --- /dev/null +++ b/service/app/agents/system/deep_research/components.py @@ -0,0 +1,373 @@ +""" +Deep Research Components - Reusable components exported by the DeepResearch agent. + +These components can be used in user-defined agents or other system agents. 
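+
+Retrieval sketch (mirrors the usage examples on each component below; assumes
+the ``get_component_config`` helper from app.agents.components):
+
+    from app.agents.components import get_component_config
+
+    analyzer_config = get_component_config("system:deep_research:query_analyzer")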
+""" + +from __future__ import annotations + +from typing import Any + +from app.agents.components import ComponentMetadata, ComponentType, NodeComponent + + +class QueryAnalyzerComponent(NodeComponent): + """ + Query Analyzer Component + + Analyzes research queries to extract: + - Main topics + - Named entities + - Effective search terms + - Query type classification + """ + + @property + def metadata(self) -> ComponentMetadata: + return ComponentMetadata( + key="system:deep_research:query_analyzer", + name="Query Analyzer", + description="Analyzes research queries to extract topics, entities, and search terms", + component_type=ComponentType.NODE, + version="1.0.0", + author="Xyzen", + tags=["research", "analysis", "nlp", "query-understanding"], + input_schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The research query to analyze", + } + }, + "required": ["query"], + }, + output_schema={ + "type": "object", + "properties": { + "topics": { + "type": "array", + "items": {"type": "string"}, + "description": "Main topics extracted from the query", + }, + "entities": { + "type": "array", + "items": {"type": "string"}, + "description": "Named entities found in the query", + }, + "search_terms": { + "type": "array", + "items": {"type": "string"}, + "description": "Generated search terms", + }, + "query_type": { + "type": "string", + "enum": ["factual", "exploratory", "comparative", "analytical"], + "description": "Classification of the query type", + }, + }, + }, + required_tools=[], + required_components=[], + ) + + def export_config(self) -> dict[str, Any]: + """Export as node configuration.""" + return { + "id": "query_analyzer", + "name": "Query Analyzer", + "type": "llm", + "description": self.metadata.description, + "llm_config": { + "prompt_template": self.PROMPT_TEMPLATE, + "output_key": "query_analysis", + "temperature_override": 0.3, + "tools_enabled": False, + }, + "tags": self.metadata.tags, + } + + def get_example_usage(self) -> str: + return """ +# Using the Query Analyzer component in a custom agent + +# 1. Import the component +from app.agents.components import get_component_config + +# 2. Get the component configuration +analyzer_config = get_component_config("system:deep_research:query_analyzer") + +# 3. Add to your graph config +graph_config = { + "nodes": [ + analyzer_config, # Add the query analyzer node + # ... your other nodes + ], + "edges": [ + {"from_node": "START", "to_node": "query_analyzer"}, + # ... your other edges + ] +} +""" + + PROMPT_TEMPLATE = """Analyze the following research query and extract structured information. + +Query: {{ state.query }} + +Please provide your analysis in the following format: + +## Topics +List 2-5 main topics that this query is about. + +## Entities +List any named entities (people, organizations, places, technical terms). + +## Search Terms +Generate 3-5 effective search queries to find information about this topic. 
+ +## Query Type +Classify this query as one of: +- factual: Looking for specific facts or data +- exploratory: Open-ended exploration of a topic +- comparative: Comparing multiple things +- analytical: Deep analysis or interpretation needed + +Provide your structured analysis:""" + + +class SynthesisComponent(NodeComponent): + """ + Synthesis Component + + Synthesizes research results into a comprehensive answer with: + - Clear structure + - Key findings + - Proper citations + """ + + @property + def metadata(self) -> ComponentMetadata: + return ComponentMetadata( + key="system:deep_research:synthesis", + name="Research Synthesis", + description="Synthesizes search results into comprehensive answers with citations", + component_type=ComponentType.NODE, + version="1.0.0", + author="Xyzen", + tags=["research", "synthesis", "citations", "summarization"], + input_schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The original research query", + }, + "search_results": { + "type": "array", + "description": "List of search results to synthesize", + }, + }, + "required": ["query", "search_results"], + }, + output_schema={ + "type": "object", + "properties": { + "synthesis": { + "type": "string", + "description": "The synthesized answer", + }, + "citations": { + "type": "array", + "description": "List of citations used", + }, + }, + }, + required_tools=[], + required_components=[], + ) + + def export_config(self) -> dict[str, Any]: + """Export as node configuration.""" + return { + "id": "synthesis", + "name": "Research Synthesis", + "type": "llm", + "description": self.metadata.description, + "llm_config": { + "prompt_template": self.PROMPT_TEMPLATE, + "output_key": "synthesis", + "temperature_override": 0.5, + "tools_enabled": False, + }, + "tags": self.metadata.tags, + } + + def get_example_usage(self) -> str: + return """ +# Using the Synthesis component in a custom agent + +# 1. Import the component +from app.agents.components import get_component_config + +# 2. Get the component configuration +synthesis_config = get_component_config("system:deep_research:synthesis") + +# 3. Add to your graph config after search nodes +graph_config = { + "nodes": [ + # ... search nodes + synthesis_config, # Add the synthesis node + ], + "edges": [ + # ... other edges + {"from_node": "search", "to_node": "synthesis"}, + {"from_node": "synthesis", "to_node": "END"}, + ] +} +""" + + PROMPT_TEMPLATE = """Based on the following research results, synthesize a comprehensive answer. + +## Original Query +{{ state.query }} + +## Research Results +{% for result in state.search_results %} +### Source {{ loop.index }} +{{ result.content | default(result | string) }} + +{% endfor %} + +## Instructions +Please synthesize the information above into a comprehensive answer that: + +1. **Directly addresses the query** - Start with the most relevant information +2. **Organizes logically** - Use clear structure and headings if helpful +3. **Cites sources** - Reference sources using [Source N] notation +4. 
**Acknowledges limitations** - Note any gaps or uncertainties + +## Your Synthesis:""" + + +class ResearchSupervisorComponent(NodeComponent): + """ + Research Supervisor Component + + Coordinates parallel research tasks by: + - Analyzing the research brief + - Delegating to sub-researchers + - Tracking research progress + - Deciding when to complete + """ + + @property + def metadata(self) -> ComponentMetadata: + return ComponentMetadata( + key="system:deep_research:supervisor", + name="Research Supervisor", + description="Coordinates parallel research by delegating to sub-researchers", + component_type=ComponentType.NODE, + version="2.0.0", + author="Xyzen", + tags=["research", "supervisor", "delegation", "parallel", "coordination"], + input_schema={ + "type": "object", + "properties": { + "research_brief": { + "type": "string", + "description": "The research brief to investigate", + }, + "max_concurrent_units": { + "type": "integer", + "description": "Maximum parallel research units", + "default": 5, + }, + }, + "required": ["research_brief"], + }, + output_schema={ + "type": "object", + "properties": { + "notes": { + "type": "array", + "items": {"type": "string"}, + "description": "Collected research notes from sub-agents", + }, + "research_complete": { + "type": "boolean", + "description": "Whether research is complete", + }, + }, + }, + required_tools=["ConductResearch", "ResearchComplete", "think_tool"], + required_components=[], + ) + + def export_config(self) -> dict[str, Any]: + """Export as node configuration.""" + return { + "id": "research_supervisor", + "name": "Research Supervisor", + "type": "llm", + "description": self.metadata.description, + "llm_config": { + "prompt_template": self.PROMPT_TEMPLATE, + "output_key": "notes", + "tools_enabled": True, + }, + "tags": self.metadata.tags, + } + + def get_example_usage(self) -> str: + return """ +# Using the Research Supervisor component in a custom agent + +# 1. Import the component +from app.agents.components import get_component_config + +# 2. Get the component configuration +supervisor_config = get_component_config("system:deep_research:supervisor") + +# 3. Add to your graph config +graph_config = { + "nodes": [ + # ... brief generation node + supervisor_config, # Add the supervisor node + # ... synthesis node + ], + "edges": [ + {"from_node": "write_brief", "to_node": "research_supervisor"}, + {"from_node": "research_supervisor", "to_node": "synthesis"}, + ] +} + +# Note: The supervisor uses ConductResearch tool to spawn sub-researchers +# and think_tool for strategic planning. +""" + + PROMPT_TEMPLATE = """You are a research supervisor coordinating research on the following brief: + +{{ state.research_brief }} + +Your available tools: +1. **ConductResearch** - Delegate research to a sub-agent with a detailed topic description +2. **ResearchComplete** - Signal that you have gathered enough information +3. 
**think_tool** - Reflect on progress and plan next steps + +Guidelines: +- Use think_tool before delegating to plan your approach +- Delegate clear, specific, non-overlapping topics +- Use parallel delegation for independent subtopics +- Stop when you have comprehensive coverage +- Maximum {{ state.max_concurrent_units | default(5) }} parallel units + +Current research notes: +{% for note in state.notes %} +{{ note }} +{% endfor %} + +Decide your next action:""" + + +# Export +__all__ = ["QueryAnalyzerComponent", "SynthesisComponent", "ResearchSupervisorComponent"] diff --git a/service/app/agents/system/deep_research/configuration.py b/service/app/agents/system/deep_research/configuration.py new file mode 100644 index 00000000..201f4a4a --- /dev/null +++ b/service/app/agents/system/deep_research/configuration.py @@ -0,0 +1,53 @@ +"""Configuration for the Deep Research agent.""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + + +class DeepResearchConfig(BaseModel): + """Configuration for the Deep Research agent. + + This configuration controls the research workflow behavior. + LLM and tools are injected via the agent's configure() method, + so no model-specific configuration is needed here. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Research flow controls + max_concurrent_research_units: int = Field( + default=5, + ge=1, + le=20, + description="Maximum number of research units to run concurrently", + ) + max_researcher_iterations: int = Field( + default=6, + ge=1, + le=10, + description="Maximum research iterations for the supervisor", + ) + max_react_tool_calls: int = Field( + default=10, + ge=1, + le=30, + description="Maximum tool calls per researcher", + ) + allow_clarification: bool = Field( + default=True, + description="Whether to ask clarifying questions before research", + ) + max_structured_output_retries: int = Field( + default=3, + ge=1, + le=10, + description="Maximum retries for structured output failures", + ) + + +# Default configuration instance +DEFAULT_CONFIG = DeepResearchConfig() + + +__all__ = ["DeepResearchConfig", "DEFAULT_CONFIG"] diff --git a/service/app/agents/system/deep_research/graph_config.py b/service/app/agents/system/deep_research/graph_config.py new file mode 100644 index 00000000..02846c08 --- /dev/null +++ b/service/app/agents/system/deep_research/graph_config.py @@ -0,0 +1,380 @@ +""" +Canonical GraphConfig definition for the Deep Research agent. + +This module defines the complete JSON-serializable configuration that can +fully reconstruct the Deep Research workflow. It serves as the single source +of truth for the agent's structure. + +NOTE: The current GraphBuilder does not yet support: +- Parallel execution nodes (for concurrent researchers) +- Loop constructs (for supervisor iteration) +- Dynamic subagent spawning + +These features are planned for a future GraphBuilder enhancement. +Until then, `use_graph_builder=True` will use a simplified workflow, +while `use_graph_builder=False` uses the full Python-coded implementation. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from app.schemas.graph_config import ( + ConditionOperator, + EdgeCondition, + GraphConfig, + GraphEdgeConfig, + GraphNodeConfig, + GraphStateSchema, + LLMNodeConfig, + NodeType, + ReducerType, + StateFieldSchema, +) + +if TYPE_CHECKING: + pass + + +def get_default_prompts() -> dict[str, str]: + """Return the default prompt templates for Deep Research. 
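+
+    Keys correspond to the Jinja-style references used in the graph config,
+    e.g. "{{ prompt_templates.clarify_with_user }}".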
+ + These can be overridden when forking the agent. + """ + from app.agents.system.deep_research.prompts import ( + CLARIFY_WITH_USER_PROMPT, + COMPRESS_RESEARCH_HUMAN_MESSAGE, + COMPRESS_RESEARCH_SYSTEM_PROMPT, + FINAL_REPORT_PROMPT, + LEAD_RESEARCHER_PROMPT, + RESEARCH_BRIEF_PROMPT, + RESEARCHER_PROMPT, + ) + + return { + "clarify_with_user": CLARIFY_WITH_USER_PROMPT, + "research_brief": RESEARCH_BRIEF_PROMPT, + "lead_researcher": LEAD_RESEARCHER_PROMPT, + "researcher": RESEARCHER_PROMPT, + "compress_research_system": COMPRESS_RESEARCH_SYSTEM_PROMPT, + "compress_research_human": COMPRESS_RESEARCH_HUMAN_MESSAGE, + "final_report": FINAL_REPORT_PROMPT, + } + + +def create_state_schema() -> GraphStateSchema: + """Create the state schema for Deep Research workflow.""" + return GraphStateSchema( + fields={ + "messages": StateFieldSchema( + type="messages", + default=[], + reducer=ReducerType.MESSAGES, + description="Conversation messages with the user", + ), + "research_brief": StateFieldSchema( + type="string", + default="", + description="Generated research brief that guides the research", + ), + "supervisor_messages": StateFieldSchema( + type="list", + default=[], + reducer=ReducerType.APPEND, + description="Messages for the research supervisor context", + ), + "notes": StateFieldSchema( + type="list", + default=[], + reducer=ReducerType.APPEND, + description="Compressed research notes from sub-agents", + ), + "raw_notes": StateFieldSchema( + type="list", + default=[], + reducer=ReducerType.APPEND, + description="Raw research findings before compression", + ), + "final_report": StateFieldSchema( + type="string", + default="", + description="Final synthesized research report", + ), + "need_clarification": StateFieldSchema( + type="bool", + default=False, + description="Whether user clarification is needed", + ), + "skip_research": StateFieldSchema( + type="bool", + default=False, + description="Whether to skip research (for follow-up requests like translation)", + ), + "research_iterations": StateFieldSchema( + type="int", + default=0, + description="Current iteration count in research loop", + ), + } + ) + + +def create_deep_research_graph_config( + max_concurrent_research_units: int = 5, + max_researcher_iterations: int = 6, + allow_clarification: bool = True, +) -> GraphConfig: + """ + Create the canonical GraphConfig for Deep Research. + + This configuration represents the complete workflow structure. + Note: Some advanced features (parallel execution, loops) require + GraphBuilder enhancements that are not yet implemented. + + Args: + max_concurrent_research_units: Max parallel researchers + max_researcher_iterations: Max supervisor iterations + allow_clarification: Whether to include clarification phase + + Returns: + Complete GraphConfig for the Deep Research agent + """ + nodes: list[GraphNodeConfig] = [] + edges: list[GraphEdgeConfig] = [] + + # --- Node Definitions --- + + # 1. 
Clarify with User (optional) + if allow_clarification: + from app.schemas.graph_config import StructuredOutputField, StructuredOutputSchema + + # Define structured output schema for clarification + # Supports three modes: follow_up (skip research), new_research, needs_clarification + clarification_schema = StructuredOutputSchema( + description="Response for user clarification decision", + fields={ + "request_type": StructuredOutputField( + type="string", + description="Type of request: 'follow_up' (no research needed), 'new_research', or 'needs_clarification'", + required=True, + default="new_research", + ), + "need_clarification": StructuredOutputField( + type="bool", + description="Whether the user needs to be asked a clarifying question", + required=True, + ), + "skip_research": StructuredOutputField( + type="bool", + description="Whether to skip research (true for follow-up requests like translation)", + required=True, + default=False, + ), + "question": StructuredOutputField( + type="string", + description="A question to ask the user to clarify the report scope (if need_clarification is true)", + default="", + required=False, + ), + "verification": StructuredOutputField( + type="string", + description="Verification message OR complete response for follow-up requests", + default="", + required=False, + ), + }, + ) + + nodes.append( + GraphNodeConfig( + id="clarify_with_user", + name="Clarify with User", + type=NodeType.LLM, + description="Analyze query and ask clarifying questions if needed", + llm_config=LLMNodeConfig( + prompt_template="{{ prompt_templates.clarify_with_user }}", + output_key="clarification_response", + temperature_override=0.3, + tools_enabled=False, + # Structured output parsing + structured_output=clarification_schema, + # Conditional message selection: + # - If need_clarification is True: show 'question' + # - If need_clarification is False: show 'verification' + message_key_condition={ + "condition_field": "need_clarification", + "true_key": "question", + "false_key": "verification", + }, + ), + tags=["clarification", "user-interaction"], + ) + ) + + # 2. Write Research Brief + nodes.append( + GraphNodeConfig( + id="write_research_brief", + name="Write Research Brief", + type=NodeType.LLM, + description="Transform user messages into a structured research brief", + llm_config=LLMNodeConfig( + prompt_template="{{ prompt_templates.research_brief }}", + output_key="research_brief", + temperature_override=0.5, + tools_enabled=False, + ), + tags=["planning", "brief"], + ) + ) + + # 3. Research Supervisor + # TODO: This node needs parallel execution support in GraphBuilder + # Currently, the Python implementation handles the supervisor loop + # and parallel researcher spawning. This is a placeholder for future + # GraphBuilder enhancement. + nodes.append( + GraphNodeConfig( + id="research_supervisor", + name="Research Supervisor", + type=NodeType.LLM, + description=( + "Coordinate research by delegating to sub-researchers. " + "NOTE: Full parallel execution requires GraphBuilder enhancement." + ), + llm_config=LLMNodeConfig( + prompt_template="{{ prompt_templates.lead_researcher }}", + output_key="supervisor_response", + tools_enabled=True, + max_iterations=max_researcher_iterations, + ), + tags=["supervisor", "delegation", "parallel"], + ) + ) + + # 4. 
Final Report Generation + nodes.append( + GraphNodeConfig( + id="final_report_generation", + name="Final Report Generation", + type=NodeType.LLM, + description="Synthesize all research findings into a comprehensive report", + llm_config=LLMNodeConfig( + prompt_template="{{ prompt_templates.final_report }}", + output_key="final_report", + tools_enabled=False, + ), + tags=["synthesis", "report"], + ) + ) + + # --- Edge Definitions --- + + if allow_clarification: + # START -> clarify_with_user + edges.append(GraphEdgeConfig(from_node="START", to_node="clarify_with_user")) + + # clarify_with_user -> END (if needs clarification) + edges.append( + GraphEdgeConfig( + from_node="clarify_with_user", + to_node="END", + condition=EdgeCondition( + state_key="need_clarification", + operator=ConditionOperator.TRUTHY, + value=None, + target="END", + ), + label="Ask clarifying question", + priority=2, + ) + ) + + # clarify_with_user -> END (if skip_research for follow-up requests) + edges.append( + GraphEdgeConfig( + from_node="clarify_with_user", + to_node="END", + condition=EdgeCondition( + state_key="skip_research", + operator=ConditionOperator.TRUTHY, + value=None, + target="END", + ), + label="Handle follow-up request directly", + priority=1, + ) + ) + + # clarify_with_user -> write_research_brief (if no clarification needed and not skipping research) + edges.append( + GraphEdgeConfig( + from_node="clarify_with_user", + to_node="write_research_brief", + condition=EdgeCondition( + state_key="need_clarification", + operator=ConditionOperator.FALSY, + value=None, + target="write_research_brief", + ), + label="Proceed to research", + priority=0, + ) + ) + else: + # START -> write_research_brief (skip clarification) + edges.append(GraphEdgeConfig(from_node="START", to_node="write_research_brief")) + + # write_research_brief -> research_supervisor + edges.append(GraphEdgeConfig(from_node="write_research_brief", to_node="research_supervisor")) + + # research_supervisor -> final_report_generation + edges.append(GraphEdgeConfig(from_node="research_supervisor", to_node="final_report_generation")) + + # final_report_generation -> END + edges.append(GraphEdgeConfig(from_node="final_report_generation", to_node="END")) + + # --- Build Complete Config --- + + entry_point = "clarify_with_user" if allow_clarification else "write_research_brief" + + return GraphConfig( + version="2.0", + state_schema=create_state_schema(), + nodes=nodes, + edges=edges, + entry_point=entry_point, + prompt_templates=get_default_prompts(), + metadata={ + "author": "Xyzen", + "version": "2.0.0", + "description": "Deep Research agent with supervisor pattern and parallel research", + "system_agent_key": "deep_research", + # Configuration parameters stored in metadata for reference + "config": { + "max_concurrent_research_units": max_concurrent_research_units, + "max_researcher_iterations": max_researcher_iterations, + "allow_clarification": allow_clarification, + }, + # TODO markers for future GraphBuilder enhancements + "pending_features": [ + "parallel_execution: research_supervisor needs to spawn parallel researchers", + "loop_construct: supervisor should iterate until ResearchComplete called", + "dynamic_spawning: spawn N researchers based on ConductResearch tool calls", + ], + }, + max_execution_time_seconds=600, # Deep research can take longer + enable_checkpoints=True, + ) + + +# Pre-built default config +DEFAULT_GRAPH_CONFIG = create_deep_research_graph_config() + + +__all__ = [ + "create_deep_research_graph_config", + 
"create_state_schema", + "get_default_prompts", + "DEFAULT_GRAPH_CONFIG", +] diff --git a/service/app/agents/system/deep_research/prompts.py b/service/app/agents/system/deep_research/prompts.py new file mode 100644 index 00000000..f8078c50 --- /dev/null +++ b/service/app/agents/system/deep_research/prompts.py @@ -0,0 +1,343 @@ +"""System prompts and prompt templates for the Deep Research agent.""" + +from __future__ import annotations + +# Clarification prompt - asks clarifying questions if research scope is unclear +CLARIFY_WITH_USER_PROMPT = """ +These are the messages that have been exchanged so far from the user asking for the report: + +{messages} + + +Today's date is {date}. + +First, analyze the conversation to determine what type of request this is: + +1. **Follow-up Request on Existing Content**: If the user's latest message is asking to perform an operation on content that was ALREADY generated in this conversation (e.g., translate, summarize, reformat, extract specific parts, explain further, etc.), this does NOT require new research. + - Examples: "Translate this to Chinese", "Summarize the key points", "Make it shorter", "Explain the third section in more detail", "Convert to bullet points" + - In this case, you should handle it directly without starting a new research cycle. + +2. **New Research Request**: If the user is asking a new question or wants to research a new topic that requires gathering information from external sources. + - Examples: "Research the latest AI developments", "What are the best restaurants in Tokyo?", "Compare iPhone vs Android" + +3. **Clarification Needed**: If the research request is unclear and you need more information from the user. + +Respond in valid JSON format with these exact keys: +"request_type": "follow_up" | "new_research" | "needs_clarification", +"need_clarification": boolean, +"question": "", +"verification": "", +"skip_research": boolean + +**For follow-up requests** (translation, summarization, reformatting, etc.): +- Set "request_type": "follow_up" +- Set "need_clarification": false +- Set "skip_research": true +- In "verification", provide the COMPLETE response to the user's request (e.g., the full translation, summary, etc.) +- This is important: you must actually perform the requested operation (translate, summarize, etc.) 
and put the full result in "verification" + +**For new research requests**: +- Set "request_type": "new_research" +- Set "need_clarification": false +- Set "skip_research": false +- In "verification", acknowledge that you will start research + +**For requests needing clarification**: +- Set "request_type": "needs_clarification" +- Set "need_clarification": true +- Set "skip_research": false +- In "question", provide your clarifying question + +Guidelines for clarification questions: +- Be concise while gathering all necessary information +- Make sure to gather all the information needed to carry out the research task +- Use bullet points or numbered lists if appropriate for clarity +- Don't ask for unnecessary information, or information that the user has already provided +- IMPORTANT: If you have already asked a clarifying question in the message history, you almost always do not need to ask another one + +Guidelines for follow-up responses: +- Actually perform the requested operation completely +- If asked to translate, provide the FULL translation +- If asked to summarize, provide a complete summary +- Match the quality and depth of the original content +""" + +# Research brief generation - transforms user messages into a research brief +RESEARCH_BRIEF_PROMPT = """You will be given a set of messages that have been exchanged so far between yourself and the user. +Your job is to translate these messages into a more detailed and concrete research question that will be used to guide the research. + +The messages that have been exchanged so far between yourself and the user are: + +{messages} + + +Today's date is {date}. + +You will return a single research question that will be used to guide the research. + +Guidelines: +1. Maximize Specificity and Detail +- Include all known user preferences and explicitly list key attributes or dimensions to consider. +- It is important that all details from the user are included in the instructions. + +2. Fill in Unstated But Necessary Dimensions as Open-Ended +- If certain attributes are essential for a meaningful output but the user has not provided them, explicitly state that they are open-ended or default to no specific constraint. + +3. Avoid Unwarranted Assumptions +- If the user has not provided a particular detail, do not invent one. +- Instead, state the lack of specification and guide the researcher to treat it as flexible or accept all possible options. + +4. Use the First Person +- Phrase the request from the perspective of the user. + +5. Sources +- If specific sources should be prioritized, specify them in the research question. +- For product and travel research, prefer linking directly to official or primary websites (e.g., official brand sites, manufacturer pages, or reputable e-commerce platforms like Amazon for user reviews) rather than aggregator sites or SEO-heavy blogs. +- For academic or scientific queries, prefer linking directly to the original paper or official journal publication rather than survey papers or secondary summaries. +- For people, try linking directly to their LinkedIn profile, or their personal website if they have one. +- If the query is in a specific language, prioritize sources published in that language. +""" + +# Lead researcher/supervisor prompt +LEAD_RESEARCHER_PROMPT = """You are a research supervisor. Your job is to conduct research by calling the "ConductResearch" tool. For context, today's date is {date}. 
+ + +Your focus is to call the "ConductResearch" tool to conduct research against the overall research question passed in by the user. +When you are completely satisfied with the research findings returned from the tool calls, then you should call the "ResearchComplete" tool to indicate that you are done with your research. + + + +You have access to three main tools: +1. **ConductResearch**: Delegate research tasks to specialized sub-agents +2. **ResearchComplete**: Indicate that research is complete +3. **think_tool**: For reflection and strategic planning during research + +**CRITICAL: Use think_tool before calling ConductResearch to plan your approach, and after each ConductResearch to assess progress. Do not call think_tool with any other tools in parallel.** + + + +Think like a research manager with limited time and resources. Follow these steps: + +1. **Read the question carefully** - What specific information does the user need? +2. **Decide how to delegate the research** - Carefully consider the question and decide how to delegate the research. Are there multiple independent directions that can be explored simultaneously? +3. **After each call to ConductResearch, pause and assess** - Do I have enough to answer? What's still missing? + + + +**Task Delegation Budgets** (Prevent excessive delegation): +- **Bias towards single agent** - Use single agent for simplicity unless the user request has clear opportunity for parallelization +- **Stop when you can answer confidently** - Don't keep delegating research for perfection +- **Limit tool calls** - Always stop after {max_researcher_iterations} tool calls to ConductResearch and think_tool if you cannot find the right sources + +**Maximum {max_concurrent_research_units} parallel agents per iteration** + + + +Before you call ConductResearch tool call, use think_tool to plan your approach: +- Can the task be broken down into smaller sub-tasks? + +After each ConductResearch tool call, use think_tool to analyze the results: +- What key information did I find? +- What's missing? +- Do I have enough to answer the question comprehensively? +- Should I delegate more research or call ResearchComplete? + + + +**Simple fact-finding, lists, and rankings** can use a single sub-agent: +- *Example*: List the top 10 coffee shops in San Francisco → Use 1 sub-agent + +**Comparisons presented in the user request** can use a sub-agent for each element of the comparison: +- *Example*: Compare OpenAI vs. Anthropic vs. DeepMind approaches to AI safety → Use 3 sub-agents +- Delegate clear, distinct, non-overlapping subtopics + +**Important Reminders:** +- Each ConductResearch call spawns a dedicated research agent for that specific topic +- A separate agent will write the final report - you just need to gather information +- When calling ConductResearch, provide complete standalone instructions - sub-agents can't see other agents' work +- Do NOT use acronyms or abbreviations in your research questions, be very clear and specific +""" + +# Individual researcher prompt +RESEARCHER_PROMPT = """You are a research assistant conducting research on the user's input topic. For context, today's date is {date}. + + +Your job is to use tools to gather information about the user's input topic. +You can use any of the tools provided to you to find resources that can help answer the research question. You can call these tools in series or in parallel, your research is conducted in a tool-calling loop. + + + +You have access to search tools and a reflection tool: +1. 
**Search tools**: For conducting web searches to gather information +2. **think_tool**: For reflection and strategic planning during research + +**CRITICAL: Use think_tool after each search to reflect on results and plan next steps. Do not call think_tool with search tools in parallel. It should be used to reflect on the results of the search.** + + + +Think like a human researcher with limited time. Follow these steps: + +1. **Read the question carefully** - What specific information does the user need? +2. **Start with broader searches** - Use broad, comprehensive queries first +3. **After each search, pause and assess** - Do I have enough to answer? What's still missing? +4. **Execute narrower searches as you gather information** - Fill in the gaps +5. **Stop when you can answer confidently** - Don't keep searching for perfection + + + +**Tool Call Budgets** (Prevent excessive searching): +- **Simple queries**: Use 2-3 search tool calls maximum +- **Complex queries**: Use up to 5 search tool calls maximum +- **Always stop**: After 5 search tool calls if you cannot find the right sources + +**Stop Immediately When**: +- You can answer the user's question comprehensively +- You have 3+ relevant examples/sources for the question +- Your last 2 searches returned similar information + + + +After each search tool call, use think_tool to analyze the results: +- What key information did I find? +- What's missing? +- Do I have enough to answer the question comprehensively? +- Should I search more or provide my answer? + +""" + +# Compression prompt - compresses research findings +COMPRESS_RESEARCH_SYSTEM_PROMPT = """You are a research assistant that has conducted research on a topic by calling several tools and web searches. Your job is now to clean up the findings, but preserve all of the relevant statements and information that the researcher has gathered. For context, today's date is {date}. + + +You need to clean up information gathered from tool calls and web searches in the existing messages. +All relevant information should be repeated and rewritten verbatim, but in a cleaner format. +The purpose of this step is just to remove any obviously irrelevant or duplicative information. +For example, if three sources all say "X", you could say "These three sources all stated X". +Only these fully comprehensive cleaned findings are going to be returned to the user, so it's crucial that you don't lose any information from the raw messages. + + + +1. Your output findings should be fully comprehensive and include ALL of the information and sources that the researcher has gathered from tool calls and web searches. It is expected that you repeat key information verbatim. +2. This report can be as long as necessary to return ALL of the information that the researcher has gathered. +3. In your report, you should return inline citations for each source that the researcher found. +4. You should include a "Sources" section at the end of the report that lists all of the sources the researcher found with corresponding citations, cited against statements in the report. +5. Make sure to include ALL of the sources that the researcher gathered in the report, and how they were used to answer the question! +6. It's really important not to lose any sources. A later LLM will be used to merge this report with others, so having all of the sources is critical. 
+ + + +The report should be structured like this: +**List of Queries and Tool Calls Made** +**Fully Comprehensive Findings** +**List of All Relevant Sources (with citations in the report)** + + + +- Assign each unique URL a single citation number in your text +- End with ### Sources that lists each source with corresponding numbers +- IMPORTANT: Number sources sequentially without gaps (1,2,3,4...) in the final list regardless of which sources you choose +- Example format: + [1] Source Title: URL + [2] Source Title: URL + + +Critical Reminder: It is extremely important that any information that is even remotely relevant to the user's research topic is preserved verbatim (e.g. don't rewrite it, don't summarize it, don't paraphrase it). +""" + +COMPRESS_RESEARCH_HUMAN_MESSAGE = """All above messages are about research conducted by an AI Researcher. Please clean up these findings. + +DO NOT summarize the information. I want the raw information returned, just in a cleaner format. Make sure all relevant information is preserved - you can rewrite findings verbatim.""" + +# Final report generation prompt +FINAL_REPORT_PROMPT = """Based on all the research conducted, create a comprehensive, well-structured answer to the overall research brief: + +{research_brief} + + +For more context, here is all of the messages so far. Focus on the research brief above, but consider these messages as well for more context. + +{messages} + +CRITICAL: Make sure the answer is written in the same language as the human messages! +For example, if the user's messages are in English, then MAKE SURE you write your response in English. If the user's messages are in Chinese, then MAKE SURE you write your entire response in Chinese. +This is critical. The user will only understand the answer if it is written in the same language as their input message. + +Today's date is {date}. + +Here are the findings from the research that you conducted: + +{findings} + + +Please create a detailed answer to the overall research brief that: +1. Is well-organized with proper headings (# for title, ## for sections, ### for subsections) +2. Includes specific facts and insights from the research +3. References relevant sources using [Title](URL) format +4. Provides a balanced, thorough analysis. Be as comprehensive as possible, and include all information that is relevant to the overall research question. People are using you for deep research and will expect detailed, comprehensive answers. +5. Includes a "Sources" section at the end with all referenced links + +You can structure your report in a number of different ways. Here are some examples: + +To answer a question that asks you to compare two things, you might structure your report like this: +1/ intro +2/ overview of topic A +3/ overview of topic B +4/ comparison between A and B +5/ conclusion + +To answer a question that asks you to return a list of things, you might only need a single section which is the entire list. +1/ list of things or table of things +Or, you could choose to make each item in the list a separate section in the report. When asked for lists, you don't need an introduction or conclusion. +1/ item 1 +2/ item 2 +3/ item 3 + +To answer a question that asks you to summarize a topic, give a report, or give an overview, you might structure your report like this: +1/ overview of topic +2/ concept 1 +3/ concept 2 +4/ concept 3 +5/ conclusion + +If you think you can answer the question with a single section, you can do that too! 
+1/ answer + +REMEMBER: Section is a VERY fluid and loose concept. You can structure your report however you think is best, including in ways that are not listed above! +Make sure that your sections are cohesive, and make sense for the reader. + +For each section of the report, do the following: +- Use simple, clear language +- Use ## for section title (Markdown format) for each section of the report +- Do NOT ever refer to yourself as the writer of the report. This should be a professional report without any self-referential language. +- Do not say what you are doing in the report. Just write the report without any commentary from yourself. +- Each section should be as long as necessary to deeply answer the question with the information you have gathered. It is expected that sections will be fairly long and verbose. You are writing a deep research report, and users will expect a thorough answer. +- Use bullet points to list out information when appropriate, but by default, write in paragraph form. + +REMEMBER: +The brief and research may be in English, but you need to translate this information to the right language when writing the final answer. +Make sure the final answer report is in the SAME language as the human messages in the message history. + +Format the report in clear markdown with proper structure and include source references where appropriate. + + +- Assign each unique URL a single citation number in your text +- End with ### Sources that lists each source with corresponding numbers +- IMPORTANT: Number sources sequentially without gaps (1,2,3,4...) in the final list regardless of which sources you choose +- Each source should be a separate line item in a list, so that in markdown it is rendered as a list. +- Example format: + [1] Source Title: URL + [2] Source Title: URL +- Citations are extremely important. Make sure to include these, and pay a lot of attention to getting these right. Users will often use these citations to look into more information. + +""" + + +__all__ = [ + "CLARIFY_WITH_USER_PROMPT", + "RESEARCH_BRIEF_PROMPT", + "LEAD_RESEARCHER_PROMPT", + "RESEARCHER_PROMPT", + "COMPRESS_RESEARCH_SYSTEM_PROMPT", + "COMPRESS_RESEARCH_HUMAN_MESSAGE", + "FINAL_REPORT_PROMPT", +] diff --git a/service/app/agents/system/deep_research/state.py b/service/app/agents/system/deep_research/state.py new file mode 100644 index 00000000..556806a1 --- /dev/null +++ b/service/app/agents/system/deep_research/state.py @@ -0,0 +1,83 @@ +"""Structured outputs and state models for the Deep Research agent. + +These models are used for: +- Tool definitions (ResearchComplete) +- LLM structured outputs (ClarifyWithUser, ResearchQuestion) +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +################### +# Structured Outputs +################### + + +class ConductResearch(BaseModel): + """Tool call to conduct research on a specific topic. + + Used by the supervisor to delegate research tasks to sub-researchers. + """ + + research_topic: str = Field( + description=( + "The topic to research. Should be a single topic, described in high detail (at least a paragraph)." + ), + ) + + +class ResearchComplete(BaseModel): + """Tool call to indicate that the research is complete. + + Used by researchers to signal completion. + """ + + pass + + +class ClarifyWithUser(BaseModel): + """Model for user clarification responses from LLM. + + Supports three request types: + - follow_up: User is asking about existing content (translate, summarize, etc.) 
+ - new_research: User wants to research a new topic + - needs_clarification: User's request is unclear + """ + + request_type: str = Field( + default="new_research", + description="Type of request: 'follow_up' (no research needed), 'new_research', or 'needs_clarification'", + ) + need_clarification: bool = Field( + description="Whether the user needs to be asked a clarifying question.", + ) + skip_research: bool = Field( + default=False, + description="Whether to skip research (true for follow-up requests like translation)", + ) + question: str = Field( + default="", + description="A question to ask the user to clarify the report scope", + ) + verification: str = Field( + default="", + description="Verification message OR complete response for follow-up requests", + ) + + +class ResearchQuestion(BaseModel): + """Research question and brief for guiding research.""" + + research_brief: str = Field( + description="A research question that will be used to guide the research.", + ) + + +__all__ = [ + "ConductResearch", + "ResearchComplete", + "ClarifyWithUser", + "ResearchQuestion", +] diff --git a/service/app/agents/system/deep_research/utils.py b/service/app/agents/system/deep_research/utils.py new file mode 100644 index 00000000..c16164e8 --- /dev/null +++ b/service/app/agents/system/deep_research/utils.py @@ -0,0 +1,201 @@ +"""Utility functions for the Deep Research agent.""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from langchain_core.messages import AIMessage, BaseMessage, MessageLikeRepresentation, filter_messages +from langchain_core.tools import tool + +if TYPE_CHECKING: + from langchain_core.tools import BaseTool + +from app.agents.system.deep_research.state import ResearchComplete + + +################### +# Think Tool +################### + + +@tool(description="Strategic reflection tool for research planning") +def think_tool(reflection: str) -> str: + """Tool for strategic reflection on research progress and decision-making. + + Use this tool after each search to analyze results and plan next steps systematically. + This creates a deliberate pause in the research workflow for quality decision-making. + + When to use: + - After receiving search results: What key information did I find? + - Before deciding next steps: Do I have enough to answer comprehensively? + - When assessing research gaps: What specific information am I still missing? + - Before concluding research: Can I provide a complete answer now? + + Reflection should address: + 1. Analysis of current findings - What concrete information have I gathered? + 2. Gap assessment - What crucial information is still missing? + 3. Quality evaluation - Do I have sufficient evidence/examples for a good answer? + 4. Strategic decision - Should I continue searching or provide my answer? + + Args: + reflection: Your detailed reflection on research progress, findings, gaps, and next steps + + Returns: + Confirmation that reflection was recorded for decision-making + """ + return f"Reflection recorded: {reflection}" + + +################### +# Tool Utilities +################### + + +def get_research_tools(session_tools: list["BaseTool"]) -> list["BaseTool"]: + """Combine session tools with research-specific tools. + + Args: + session_tools: Tools from the session (MCP tools, Google search, etc.) 
+ + Returns: + List of all tools available to researchers + """ + # Research control tools + research_tools: list["BaseTool"] = [ + tool(ResearchComplete), # Signal research completion + think_tool, # Strategic reflection + ] + + # Add all session tools (MCP tools, Google search if enabled) + research_tools.extend(session_tools) + + return research_tools + + +def get_notes_from_tool_calls(messages: list[MessageLikeRepresentation]) -> list[str]: + """Extract notes/content from tool call messages. + + Args: + messages: List of messages containing tool calls + + Returns: + List of tool message contents as strings + """ + result: list[str] = [] + for tool_msg in filter_messages(messages, include_types="tool"): + content = tool_msg.content + if isinstance(content, str): + result.append(content) + else: + result.append(str(content)) + return result + + +################### +# Message Utilities +################### + + +def remove_up_to_last_ai_message( + messages: list[MessageLikeRepresentation], +) -> list[MessageLikeRepresentation]: + """Truncate message history by removing up to the last AI message. + + This is useful for handling token limit exceeded errors by removing recent context. + + Args: + messages: List of message objects to truncate + + Returns: + Truncated message list up to (but not including) the last AI message + """ + # Search backwards through messages to find the last AI message + for i in range(len(messages) - 1, -1, -1): + if isinstance(messages[i], AIMessage): + # Return everything up to (but not including) the last AI message + return messages[:i] + + # No AI messages found, return original list + return messages + + +def get_tool_message_content(messages: list[MessageLikeRepresentation]) -> str: + """Extract content from all tool and AI messages. + + Args: + messages: List of messages to extract content from + + Returns: + Concatenated content from tool and AI messages + """ + filtered = filter_messages(messages, include_types=["tool", "ai"]) + return "\n".join(str(message.content) for message in filtered) + + +################### +# Date Utilities +################### + + +def get_today_str() -> str: + """Get current date formatted for display in prompts and outputs. + + Returns: + Human-readable date string in format like 'Mon Jan 15, 2024' + """ + now = datetime.now() + return f"{now:%a} {now:%b} {now.day}, {now:%Y}" + + +################### +# Buffer Utilities +################### + + +def get_buffer_string(messages: list[MessageLikeRepresentation]) -> str: + """Convert messages to a string buffer for prompt formatting. 
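+
+    For example (an illustrative sketch), a HumanMessage("hi") followed by an
+    AIMessage("hello") renders as "human: hi\nai: hello".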
+
+    Args:
+        messages: List of messages to convert
+
+    Returns:
+        Formatted string representation of messages
+    """
+    buffer_parts: list[str] = []
+    for msg in messages:
+        # Handle actual message objects (most common case)
+        if isinstance(msg, BaseMessage):
+            role = msg.type
+            content = msg.content if isinstance(msg.content, str) else str(msg.content)
+        # Handle string messages
+        elif isinstance(msg, str):
+            role = "unknown"
+            content = msg
+        # Handle tuple messages (role, content)
+        elif isinstance(msg, tuple) and len(msg) == 2:
+            role = str(msg[0])
+            content = str(msg[1])
+        # Handle dict messages
+        elif isinstance(msg, dict):
+            role = str(msg.get("type", msg.get("role", "unknown")))
+            content = str(msg.get("content", msg))
+        # Handle list and other types
+        else:
+            role = "unknown"
+            content = str(msg)
+
+        buffer_parts.append(f"{role}: {content}")
+
+    return "\n".join(buffer_parts)
+
+
+__all__ = [
+    "think_tool",
+    "get_research_tools",
+    "get_notes_from_tool_calls",
+    "remove_up_to_last_ai_message",
+    "get_tool_message_content",
+    "get_today_str",
+    "get_buffer_string",
+]
diff --git a/service/app/agents/system/react/__init__.py b/service/app/agents/system/react/__init__.py
new file mode 100644
index 00000000..a7814bd4
--- /dev/null
+++ b/service/app/agents/system/react/__init__.py
@@ -0,0 +1,19 @@
+"""
+ReAct System Agent
+
+The default tool-calling agent using the ReAct (Reasoning + Acting) pattern.
+This is the standard agent for chat conversations with tool use.
+
+Usage:
+    from app.agents.system.react import ReActAgent
+
+    agent = ReActAgent()
+    agent.configure(llm=my_llm, tools=[tool1, tool2])
+    graph = agent.build_graph()
+
+    result = await graph.ainvoke({"messages": [HumanMessage(content="Hello")]})
+"""
+
+from .agent import ReActAgent
+
+__all__ = ["ReActAgent"]
diff --git a/service/app/agents/system/react/agent.py b/service/app/agents/system/react/agent.py
new file mode 100644
index 00000000..3d25a4e9
--- /dev/null
+++ b/service/app/agents/system/react/agent.py
@@ -0,0 +1,204 @@
+"""
+ReAct System Agent - Default tool-calling agent for chat conversations.
+
+This module provides the standard ReAct (Reasoning + Acting) agent that uses
+LangChain's create_agent for tool-calling conversations.
+
+As a system agent, ReAct:
+- Is the default agent when no agent is specified
+- Can be forked by users to customize behavior
+- Exports its configuration for JSON-based customization
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from langchain.agents import create_agent
+from langchain_core.tools import BaseTool
+from langgraph.graph.state import CompiledStateGraph
+
+from app.agents.components import BaseComponent
+from app.agents.system.base import BaseSystemAgent
+from app.schemas.graph_config import (
+    GraphConfig,
+    GraphEdgeConfig,
+    GraphNodeConfig,
+    GraphStateSchema,
+    LLMNodeConfig,
+    NodeType,
+    ReducerType,
+    StateFieldSchema,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ReActAgent(BaseSystemAgent):
+    """
+    Default ReAct agent for tool-calling conversations.
+
+    Uses LangChain's create_agent (successor to the deprecated prebuilt
+    create_react_agent), which implements the ReAct pattern:
+    Reasoning + Acting with tool calls.
+
+    This agent:
+    - Processes user messages
+    - Decides whether to use tools or respond directly
+    - Executes tools and incorporates results
+    - Generates final responses
+
+    Supports combining provider-side tools (like Google Search) with
+    client-side tools (like MCP tools) by binding them together.
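+
+    Example (an illustrative sketch; `llm` and `mcp_tools` are assumed to be
+    provided by the caller):
+
+        agent = ReActAgent(system_prompt="Be concise.", google_search_enabled=True)
+        agent.configure(llm=llm, tools=mcp_tools)
+        graph = agent.build_graph()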
+
+    As a system agent, ReAct is available to all users and can be
+    forked to create customized versions.
+    """
+
+    SYSTEM_KEY = "react"
+
+    # Additional configuration options
+    system_prompt: str
+    google_search_enabled: bool
+
+    def __init__(
+        self,
+        system_prompt: str = "",
+        google_search_enabled: bool = False,
+    ) -> None:
+        """
+        Initialize the ReAct agent.
+
+        Args:
+            system_prompt: System prompt to guide agent behavior
+            google_search_enabled: Enable Google's built-in web search
+        """
+        super().__init__(
+            name="ReAct Agent",
+            description="Default tool-calling agent using ReAct pattern for reasoning and acting",
+            version="1.0.0",
+            capabilities=["tool-calling", "reasoning", "multi-turn-conversation"],
+            tags=["default", "react", "chat", "tool-calling"],
+            author="Xyzen",
+        )
+        self.system_prompt = system_prompt
+        self.google_search_enabled = google_search_enabled
+
+    def build_graph(self) -> CompiledStateGraph[Any, None, Any, Any]:
+        """
+        Build the ReAct agent graph using LangChain's create_agent.
+
+        When google_search_enabled is True, the provider-side google_search
+        tool is bound to the model at LLM creation time, alongside the MCP
+        tools bound here.
+
+        Returns:
+            Compiled StateGraph ready for execution
+        """
+        if not self.llm:
+            raise RuntimeError("LLM not configured. Call configure() first.")
+
+        tools = self.tools or []
+        logger.info(f"Building ReAct agent with {len(tools)} tools, google_search={self.google_search_enabled}")
+
+        # Combine all tools for binding
+        # MCP tools (client-side) are passed as-is
+        # Provider-side web search is bound at model creation time.
+        all_tools: list[BaseTool] = list(tools)
+
+        # Use LangChain's create_agent (replacement for deprecated create_react_agent)
+        # Pass all tools together so they're bound in a single call
+        agent: CompiledStateGraph[Any, None, Any, Any] = create_agent(
+            model=self.llm,
+            tools=all_tools,
+            system_prompt=self.system_prompt if self.system_prompt else None,
+        )
+
+        logger.debug("ReAct agent graph built successfully")
+        return agent
+
+    def get_state_schema(self) -> dict[str, Any]:
+        """
+        Return the state schema for ReAct agent.
+
+        The graph produced by create_agent uses a standard messages-based schema.
+
+        Returns:
+            State schema definition
+        """
+        return {
+            "messages": "list[BaseMessage] - Conversation messages",
+        }
+
+    def export_graph_config(self) -> GraphConfig:
+        """
+        Export the ReAct agent's workflow as a JSON GraphConfig.
+
+        Note: The actual ReAct implementation uses LangChain's
+        create_agent, so this export is a simplified representation
+        that captures the essential structure.
+
+        Returns:
+            GraphConfig representing this agent's workflow
+        """
+        return GraphConfig(
+            version="1.0",
+            state_schema=GraphStateSchema(
+                fields={
+                    "messages": StateFieldSchema(
+                        type="list",
+                        description="Conversation messages",
+                        reducer=ReducerType.APPEND,
+                    ),
+                }
+            ),
+            nodes=[
+                GraphNodeConfig(
+                    id="agent",
+                    name="ReAct Agent",
+                    type=NodeType.LLM,
+                    description="Process messages and decide on tool use or response",
+                    llm_config=LLMNodeConfig(
+                        prompt_template=self.system_prompt or "You are a helpful assistant.",
+                        output_key="response",
+                        tools_enabled=True,
+                    ),
+                ),
+            ],
+            edges=[
+                GraphEdgeConfig(from_node="START", to_node="agent"),
+                GraphEdgeConfig(from_node="agent", to_node="END"),
+            ],
+            entry_point="agent",
+            metadata={
+                "author": "Xyzen",
+                "version": "1.0.0",
+                "description": "Default ReAct agent for tool-calling conversations",
+                "note": "This is a simplified representation. 
The actual implementation uses LangGraph's prebuilt create_agent.", + }, + ) + + def get_exported_components(self) -> list[BaseComponent]: + """ + Return list of reusable components this agent provides. + + The ReAct agent uses LangGraph's prebuilt implementation, + so it doesn't export custom components. + + Returns: + Empty list (no custom components) + """ + return [] + + def supports_streaming(self) -> bool: + """ReAct agent supports streaming.""" + return True + + def get_required_tools(self) -> list[str]: + """Return names of tools configured for this agent.""" + if self.tools: + return [tool.name for tool in self.tools] + return [] + + +# Export +__all__ = ["ReActAgent"] diff --git a/service/app/agents/types.py b/service/app/agents/types.py new file mode 100644 index 00000000..245df00a --- /dev/null +++ b/service/app/agents/types.py @@ -0,0 +1,154 @@ +""" +Agent Type Definitions - Type aliases and protocols for the agent system. + +This module provides type definitions for: +- StateGraph and CompiledStateGraph with proper type parameters +- Node function signatures +- LLM factory callables +- System agent metadata structures +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypedDict, TypeVar + +from langchain_core.messages import BaseMessage +from langgraph.graph.state import CompiledStateGraph, StateGraph +from pydantic import BaseModel +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from langchain_core.language_models import BaseChatModel + + +# ============================================================================= +# State Types +# ============================================================================= + + +class BaseGraphState(TypedDict, total=False): + """ + Minimum fields present in all graph states. + + All dynamic graph states will have at least these fields, + though additional fields may be added at runtime. 
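+
+    Example (illustrative; the class is declared total=False, so both fields
+    may be absent):
+
+        state: BaseGraphState = {"messages": [], "execution_context": {}}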
+ """ + + messages: list[BaseMessage] + execution_context: dict[str, Any] + + +# Type alias for dynamic state dictionaries +# Used when the state schema is determined at runtime +StateDict = dict[str, Any] + + +# ============================================================================= +# StateGraph Type Aliases +# ============================================================================= + +# TypeVar for state schema types +S = TypeVar("S", bound=BaseModel) + +# For agents with known state types (e.g., DeepResearchState) +# Usage: TypedStateGraph[MyState] or TypedCompiledGraph[MyState] +TypedStateGraph = StateGraph[S, None, S, S] +TypedCompiledGraph = CompiledStateGraph[S, None, S, S] + +# For dynamic/runtime-determined state types (graph builder) +# These use BaseModel as a placeholder since the actual type is created dynamically +DynamicStateGraph = StateGraph[BaseModel, None, BaseModel, BaseModel] +DynamicCompiledGraph = CompiledStateGraph[BaseModel, None, BaseModel, BaseModel] + + +# ============================================================================= +# Node Function Types +# ============================================================================= + +# Async node function: takes state dict, returns partial state update +NodeFunction = Callable[[StateDict], Awaitable[StateDict]] + +# Sync routing function: takes state dict, returns next node name +RouterFunction = Callable[[StateDict], str] + +# Generic callable for node actions (used by LangGraph's add_node) +NodeAction = Callable[[StateDict], Awaitable[StateDict]] | Callable[[StateDict], StateDict] + + +# ============================================================================= +# LLM Factory Type +# ============================================================================= + +# Factory function that creates LLM instances with optional overrides +# Signature: async def factory(model=None, temperature=None, ...) -> BaseChatModel +LLMFactory = Callable[..., Awaitable["BaseChatModel"]] + + +# ============================================================================= +# System Agent Metadata Types +# ============================================================================= + + +class AgentMetadata(TypedDict): + """Metadata returned by BaseBuiltinGraphAgent.get_metadata().""" + + name: str + description: str + version: str + capabilities: list[str] + tags: list[str] + author: NotRequired[str | None] + license: NotRequired[str | None] + + +class ComponentMetadataDict(TypedDict): + """Serialized component metadata.""" + + key: str + name: str + description: str + component_type: str + version: str + author: NotRequired[str | None] + tags: NotRequired[list[str]] + + +class SystemAgentInfo(TypedDict): + """ + Info returned by list_available_system_agents(). + + Contains metadata about a system agent including its key, + metadata, whether it can be forked, and exported components. 
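+
+    Example shape (illustrative values):
+
+        {"key": "react", "metadata": {...}, "forkable": True, "components": []}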
+ """ + + key: str + metadata: AgentMetadata + forkable: bool + components: list[ComponentMetadataDict] + error: NotRequired[str] + + +# ============================================================================= +# Exports +# ============================================================================= + +__all__ = [ + # State types + "BaseGraphState", + "StateDict", + # StateGraph aliases + "S", + "TypedStateGraph", + "TypedCompiledGraph", + "DynamicStateGraph", + "DynamicCompiledGraph", + # Function types + "NodeFunction", + "RouterFunction", + "NodeAction", + "LLMFactory", + # Metadata types + "AgentMetadata", + "ComponentMetadataDict", + "SystemAgentInfo", +] diff --git a/service/app/api/v1/agents.py b/service/app/api/v1/agents.py index 263b0966..83a2f960 100644 --- a/service/app/api/v1/agents.py +++ b/service/app/api/v1/agents.py @@ -21,6 +21,7 @@ from app.core.system_agent import SystemAgentManager from app.infra.database import get_session from app.middleware.auth import get_current_user +from app.agents.types import SystemAgentInfo from app.models.agent import AgentCreate, AgentRead, AgentReadWithDetails, AgentScope, AgentUpdate from app.repos import AgentRepository, KnowledgeSetRepository, ProviderRepository from app.repos.agent_marketplace import AgentMarketplaceRepository @@ -132,6 +133,79 @@ async def get_agents( return agents_with_details +@router.get("/templates/system", response_model=list[SystemAgentInfo]) +async def get_system_agent_templates( + user: str = Depends(get_current_user), +) -> list[SystemAgentInfo]: + """ + Get all available system agent templates that users can add. + + Returns a list of system agents (like ReAct, Deep Research) that users + can create instances of. Each template includes metadata about the agent's + capabilities and purpose. + + Args: + user: Authenticated user ID (injected by dependency) + + Returns: + list[SystemAgentInfo]: List of available system agent templates + """ + # Lazy import to avoid circular dependency + from app.agents.factory import list_available_system_agents + + return list_available_system_agents() + + +@router.post("/from-template/{system_key}", response_model=AgentRead) +async def create_agent_from_template( + system_key: str, + user_id: str = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> AgentRead: + """ + Create a new agent from a system agent template. + + This creates a user agent with the system agent's graph_config pre-populated, + allowing users to use or customize system agents like Deep Research. 
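+
+    Example (an illustrative request; the actual path prefix depends on how
+    this router is mounted):
+
+        POST /agents/from-template/deep_research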
+ + Args: + system_key: Key of the system agent template (e.g., "react", "deep_research") + user_id: Authenticated user ID (injected by dependency) + db: Database session (injected by dependency) + + Returns: + AgentRead: The newly created agent with graph_config from the template + + Raises: + HTTPException: 404 if system agent template not found + """ + from app.agents.system import system_agent_registry + + # Get the system agent class + agent_class = system_agent_registry.get_class(system_key) + if not agent_class: + raise HTTPException(status_code=404, detail=f"System agent template '{system_key}' not found") + + # Create an instance and get the forkable config (includes graph_config) + system_agent = agent_class() + forkable_config = system_agent.get_forkable_config() + + # Create the agent with the exported graph_config + agent_data = AgentCreate( + scope=AgentScope.USER, + name=forkable_config.get("name", system_agent.name), + description=forkable_config.get("description", system_agent.description), + tags=forkable_config.get("tags", []), + graph_config=forkable_config.get("graph_config"), + ) + + agent_repo = AgentRepository(db) + created_agent = await agent_repo.create_agent(agent_data, user_id) + + await db.commit() + return AgentRead(**created_agent.model_dump()) + + @router.get("/{agent_id}", response_model=AgentReadWithDetails) async def get_agent( agent_id: UUID, @@ -212,8 +286,8 @@ async def update_agent( if provider.user_id != agent.user_id and not provider.is_system: raise HTTPException(status_code=403, detail="Provider access denied") - # Validate knowledge_set_id if being updated - if agent_data.knowledge_set_id is not None: + # Validate knowledge_set_id only if it's being changed to a different value + if agent_data.knowledge_set_id is not None and agent_data.knowledge_set_id != agent.knowledge_set_id: knowledge_set_repo = KnowledgeSetRepository(db) knowledge_set = await knowledge_set_repo.get_knowledge_set_by_id(agent_data.knowledge_set_id) if not knowledge_set or knowledge_set.user_id != user_id or knowledge_set.is_deleted: diff --git a/service/app/api/v1/files.py b/service/app/api/v1/files.py index b7c71c7d..98321b4b 100644 --- a/service/app/api/v1/files.py +++ b/service/app/api/v1/files.py @@ -933,7 +933,7 @@ async def bulk_delete_files( ) -def render_pptx_table(table: "pptx.table.Table | None", slide_width_pt: float) -> BytesIO: # type: ignore # noqa: F821 +def render_pptx_table(table: Any | None, slide_width_pt: float) -> BytesIO: """ Render a PowerPoint table to PNG image with proper formatting. 
@@ -948,8 +948,8 @@ def render_pptx_table(table: "pptx.table.Table | None", slide_width_pt: float) - return BytesIO() # Table parameters - rows = table.rows # type: ignore - cols = table.columns # type: ignore + rows = table.rows + cols = table.columns num_rows = len(rows) num_cols = len(cols) @@ -1307,15 +1307,6 @@ def convert_docx_to_pdf_bytes(docx_data: bytes) -> BytesIO: """ import io - try: - import fitz # pymupdf - except ImportError: - try: - import pymupdf as fitz # noqa: F401 - except ImportError: - logger.error("pymupdf not installed, cannot convert DOCX to PDF") - raise ImportError("pymupdf is required for DOCX to PDF conversion") - try: # Try using libreoffice via command line if available import subprocess @@ -1500,15 +1491,6 @@ def convert_xlsx_to_pdf_bytes(xlsx_data: bytes) -> BytesIO: logger.warning(f"Could not determine orientation, defaulting to landscape: {e}") use_landscape = True - try: - import fitz # pymupdf - except ImportError: - try: - import pymupdf as fitz # noqa: F401 - except ImportError: - logger.error("pymupdf not installed, cannot convert XLSX to PDF") - raise ImportError("pymupdf is required for XLSX to PDF conversion") - try: # Try using libreoffice via command line if available import subprocess @@ -1600,7 +1582,7 @@ def convert_xlsx_to_pdf_bytes(xlsx_data: bytes) -> BytesIO: data = [] col_widths = {} - for row_idx, row in enumerate(ws.iter_rows(values_only=True)): + for row in ws.iter_rows(values_only=True): row_data = [] for col_idx, cell_value in enumerate(row): # Convert value to string diff --git a/service/app/api/ws/v1/chat.py b/service/app/api/ws/v1/chat.py index 0a1fc390..db719357 100644 --- a/service/app/api/ws/v1/chat.py +++ b/service/app/api/ws/v1/chat.py @@ -14,7 +14,7 @@ from app.middleware.auth import AuthContext, get_auth_context_websocket from app.models.message import MessageCreate from app.repos import FileRepository, MessageRepository, SessionRepository, TopicRepository -from app.schemas.chat_events import ChatClientEventType, ChatEventType +from app.schemas.chat_event_types import ChatClientEventType, ChatEventType # from app.core.celery_app import celery_app # Not needed directly if we import the task from app.tasks.chat import process_chat_message diff --git a/service/app/core/chat/agent_event_handler.py b/service/app/core/chat/agent_event_handler.py new file mode 100644 index 00000000..efb0e9f4 --- /dev/null +++ b/service/app/core/chat/agent_event_handler.py @@ -0,0 +1,417 @@ +""" +Agent Event Handler - Utilities for emitting agent execution events. + +This module provides a clean interface for emitting structured agent +execution events during graph-based agent execution. 
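+
+Illustrative usage from inside an agent's event stream (a sketch; `agent` and
+the surrounding async generator are assumed to exist in the caller):
+
+    ctx = AgentEventContext(agent_id=str(agent.id), agent_name=agent.name, agent_type="graph")
+    yield AgentEventHandler.emit_agent_start(ctx, total_nodes=3)
+    ...
+    yield AgentEventHandler.emit_agent_end(ctx, status="completed")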
+""" + +from __future__ import annotations + +import time +import uuid +from dataclasses import dataclass, field +from typing import Any + +from app.schemas.agent_event_payloads import ( + AgentEndData, + AgentErrorData, + AgentExecutionContext, + AgentStartData, + IterationEndData, + IterationStartData, + NodeEndData, + NodeStartData, + PhaseEndData, + PhaseStartData, + ProgressUpdateData, + StateUpdateData, + SubagentEndData, + SubagentStartData, +) +from app.schemas.chat_event_payloads import ( + AgentEndEvent, + AgentErrorEvent, + AgentStartEvent, + IterationEndEvent, + IterationStartEvent, + NodeEndEvent, + NodeStartEvent, + PhaseEndEvent, + PhaseStartEvent, + ProgressUpdateEvent, + StateUpdateEvent, + SubagentEndEvent, + SubagentStartEvent, +) +from app.schemas.chat_event_types import ChatEventType + + +@dataclass +class AgentEventContext: + """ + Maintains execution context for event emission. + + This class tracks the current execution state and provides methods + for creating child contexts for subagent execution. + """ + + agent_id: str + agent_name: str + agent_type: str # actual system key (e.g., "react", "deep_research") or "graph" + execution_id: str = field(default_factory=lambda: f"exec_{uuid.uuid4().hex[:12]}") + parent_execution_id: str | None = None + depth: int = 0 + execution_path: list[str] = field(default_factory=list) + started_at: float = field(default_factory=time.time) + current_node: str | None = None + current_phase: str | None = None + + def __post_init__(self) -> None: + """Initialize execution path if empty.""" + if not self.execution_path: + self.execution_path = [self.agent_name] + + def to_context_dict(self) -> AgentExecutionContext: + """Convert to AgentExecutionContext dictionary for events.""" + ctx: AgentExecutionContext = { + "agent_id": self.agent_id, + "agent_name": self.agent_name, + "agent_type": self.agent_type, + "execution_id": self.execution_id, + "depth": self.depth, + "execution_path": self.execution_path, + "started_at": self.started_at, + "elapsed_ms": int((time.time() - self.started_at) * 1000), + } + + if self.parent_execution_id: + ctx["parent_execution_id"] = self.parent_execution_id + if self.current_node: + ctx["current_node"] = self.current_node + if self.current_phase: + ctx["current_phase"] = self.current_phase + + return ctx + + def child_context( + self, + subagent_id: str, + subagent_name: str, + subagent_type: str = "subagent", + ) -> "AgentEventContext": + """ + Create a child context for subagent execution. + + Args: + subagent_id: UUID of the subagent + subagent_name: Name of the subagent + subagent_type: Type of the subagent + + Returns: + New AgentEventContext for the subagent + """ + return AgentEventContext( + agent_id=subagent_id, + agent_name=subagent_name, + agent_type=subagent_type, + execution_id=f"{self.execution_id}:{subagent_id[:8]}", + parent_execution_id=self.execution_id, + depth=self.depth + 1, + execution_path=self.execution_path + [subagent_name], + started_at=time.time(), + ) + + def set_current_node(self, node_id: str | None) -> None: + """Update the current node being executed.""" + self.current_node = node_id + + def set_current_phase(self, phase: str | None) -> None: + """Update the current phase.""" + self.current_phase = phase + + +class AgentEventHandler: + """ + Static utility class for creating agent execution events. + + All methods return AgentEvent-typed envelopes that can be yielded + from the agent execution stream. 
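+
+    Example (a sketch; callers are assumed to track node start times themselves):
+
+        start = time.time()
+        yield AgentEventHandler.emit_node_start(ctx, "agent", "ReAct Agent", "llm")
+        ...  # execute the node
+        yield AgentEventHandler.emit_node_end(ctx, "agent", "ReAct Agent", "llm", "completed", start)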
+ """ + + # === Agent Lifecycle === + + @staticmethod + def emit_agent_start( + ctx: AgentEventContext, + total_nodes: int | None = None, + estimated_duration_ms: int | None = None, + ) -> AgentStartEvent: + """Emit AGENT_START event.""" + data: AgentStartData = {"context": ctx.to_context_dict()} + if total_nodes is not None: + data["total_nodes"] = total_nodes + if estimated_duration_ms is not None: + data["estimated_duration_ms"] = estimated_duration_ms + + return {"type": ChatEventType.AGENT_START, "data": data} + + @staticmethod + def emit_agent_end( + ctx: AgentEventContext, + status: str, + output_summary: str | None = None, + ) -> AgentEndEvent: + """Emit AGENT_END event.""" + data: AgentEndData = { + "context": ctx.to_context_dict(), + "status": status, + "duration_ms": int((time.time() - ctx.started_at) * 1000), + } + if output_summary: + data["output_summary"] = output_summary + + return {"type": ChatEventType.AGENT_END, "data": data} + + @staticmethod + def emit_agent_error( + ctx: AgentEventContext, + error: Exception, + recoverable: bool = False, + node_id: str | None = None, + ) -> AgentErrorEvent: + """Emit AGENT_ERROR event.""" + data: AgentErrorData = { + "context": ctx.to_context_dict(), + "error_type": type(error).__name__, + "error_message": str(error), + "recoverable": recoverable, + } + if node_id: + data["node_id"] = node_id + + return {"type": ChatEventType.AGENT_ERROR, "data": data} + + # === Phase Events === + + @staticmethod + def emit_phase_start( + ctx: AgentEventContext, + phase_id: str, + phase_name: str, + description: str | None = None, + expected_duration_ms: int | None = None, + ) -> PhaseStartEvent: + """Emit PHASE_START event.""" + ctx.set_current_phase(phase_id) + + data: PhaseStartData = { + "phase_id": phase_id, + "phase_name": phase_name, + "context": ctx.to_context_dict(), + } + if description: + data["description"] = description + if expected_duration_ms is not None: + data["expected_duration_ms"] = expected_duration_ms + + return {"type": ChatEventType.PHASE_START, "data": data} + + @staticmethod + def emit_phase_end( + ctx: AgentEventContext, + phase_id: str, + phase_name: str, + status: str, + start_time: float, + output_summary: str | None = None, + ) -> PhaseEndEvent: + """Emit PHASE_END event.""" + data: PhaseEndData = { + "phase_id": phase_id, + "phase_name": phase_name, + "status": status, + "duration_ms": int((time.time() - start_time) * 1000), + "context": ctx.to_context_dict(), + } + if output_summary: + data["output_summary"] = output_summary + + return {"type": ChatEventType.PHASE_END, "data": data} + + # === Node Events === + + @staticmethod + def emit_node_start( + ctx: AgentEventContext, + node_id: str, + node_name: str, + node_type: str, + input_summary: str | None = None, + ) -> NodeStartEvent: + """Emit NODE_START event.""" + ctx.set_current_node(node_id) + + data: NodeStartData = { + "node_id": node_id, + "node_name": node_name, + "node_type": node_type, + "context": ctx.to_context_dict(), + } + if input_summary: + data["input_summary"] = input_summary + + return {"type": ChatEventType.NODE_START, "data": data} + + @staticmethod + def emit_node_end( + ctx: AgentEventContext, + node_id: str, + node_name: str, + node_type: str, + status: str, + start_time: float, + output_summary: str | None = None, + ) -> NodeEndEvent: + """Emit NODE_END event.""" + data: NodeEndData = { + "node_id": node_id, + "node_name": node_name, + "node_type": node_type, + "status": status, + "duration_ms": int((time.time() - start_time) * 1000), + 
"context": ctx.to_context_dict(), + } + if output_summary: + data["output_summary"] = output_summary + + return {"type": ChatEventType.NODE_END, "data": data} + + # === Subagent Events === + + @staticmethod + def emit_subagent_start( + ctx: AgentEventContext, + subagent_id: str, + subagent_name: str, + subagent_type: str = "graph", + input_summary: str | None = None, + ) -> SubagentStartEvent: + """Emit SUBAGENT_START event.""" + data: SubagentStartData = { + "subagent_id": subagent_id, + "subagent_name": subagent_name, + "subagent_type": subagent_type, + "context": ctx.to_context_dict(), + } + if input_summary: + data["input_summary"] = input_summary + + return {"type": ChatEventType.SUBAGENT_START, "data": data} + + @staticmethod + def emit_subagent_end( + ctx: AgentEventContext, + subagent_id: str, + subagent_name: str, + status: str, + start_time: float, + output_summary: str | None = None, + ) -> SubagentEndEvent: + """Emit SUBAGENT_END event.""" + data: SubagentEndData = { + "subagent_id": subagent_id, + "subagent_name": subagent_name, + "status": status, + "duration_ms": int((time.time() - start_time) * 1000), + "context": ctx.to_context_dict(), + } + if output_summary: + data["output_summary"] = output_summary + + return {"type": ChatEventType.SUBAGENT_END, "data": data} + + # === Progress Events === + + @staticmethod + def emit_progress( + ctx: AgentEventContext, + percent: int, + message: str, + details: dict[str, Any] | None = None, + ) -> ProgressUpdateEvent: + """Emit PROGRESS_UPDATE event.""" + data: ProgressUpdateData = { + "progress_percent": max(0, min(100, percent)), + "message": message, + "context": ctx.to_context_dict(), + } + if details: + data["details"] = details + + return {"type": ChatEventType.PROGRESS_UPDATE, "data": data} + + # === Iteration Events === + + @staticmethod + def emit_iteration_start( + ctx: AgentEventContext, + iteration_number: int, + max_iterations: int, + reason: str | None = None, + ) -> IterationStartEvent: + """Emit ITERATION_START event.""" + data: IterationStartData = { + "iteration_number": iteration_number, + "max_iterations": max_iterations, + "context": ctx.to_context_dict(), + } + if reason: + data["reason"] = reason + + return {"type": ChatEventType.ITERATION_START, "data": data} + + @staticmethod + def emit_iteration_end( + ctx: AgentEventContext, + iteration_number: int, + will_continue: bool, + reason: str | None = None, + ) -> IterationEndEvent: + """Emit ITERATION_END event.""" + data: IterationEndData = { + "iteration_number": iteration_number, + "will_continue": will_continue, + "context": ctx.to_context_dict(), + } + if reason: + data["reason"] = reason + + return {"type": ChatEventType.ITERATION_END, "data": data} + + # === State Events === + + @staticmethod + def emit_state_update( + ctx: AgentEventContext, + updated_keys: list[str], + summary: dict[str, str], + ) -> StateUpdateEvent: + """ + Emit STATE_UPDATE event. + + Only include non-sensitive, summarized state information. 
+ """ + data: StateUpdateData = { + "updated_keys": updated_keys, + "summary": summary, + "context": ctx.to_context_dict(), + } + + return {"type": ChatEventType.STATE_UPDATE, "data": data} + + +# Export +__all__ = [ + "AgentEventContext", + "AgentEventHandler", +] diff --git a/service/app/core/chat/history.py b/service/app/core/chat/history.py index c898fe0c..3252956d 100644 --- a/service/app/core/chat/history.py +++ b/service/app/core/chat/history.py @@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, Any from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langchain_core.messages.tool import ToolMessage +from langchain_core.messages.tool import ToolCall, ToolMessage from sqlmodel.ext.asyncio.session import AsyncSession -from app.schemas.chat_events import ChatEventType +from app.schemas.chat_event_types import ChatEventType if TYPE_CHECKING: from app.models.topic import Topic as TopicModel @@ -112,6 +112,11 @@ async def _build_assistant_message(db: AsyncSession, message: Any, content: str) """Build an AIMessage with optional multimodal content (e.g., generated images).""" from app.core.chat.multimodal import process_message_files + # Extract agent_metadata if present + additional_kwargs: dict[str, Any] = {} + if hasattr(message, "agent_metadata") and message.agent_metadata: + additional_kwargs["agent_state"] = message.agent_metadata + try: file_contents = await process_message_files(db, message.id) if file_contents: @@ -121,13 +126,13 @@ async def _build_assistant_message(db: AsyncSession, message: Any, content: str) if content: multimodal_content.append({"type": "text", "text": content}) multimodal_content.extend(file_contents) - return AIMessage(content=multimodal_content) # pyright: ignore + return AIMessage(content=multimodal_content, additional_kwargs=additional_kwargs) # pyright: ignore - return AIMessage(content=content) + return AIMessage(content=content, additional_kwargs=additional_kwargs) except Exception as e: logger.error(f"Failed to process files for assistant message {message.id}: {e}", exc_info=True) - return AIMessage(content=content) + return AIMessage(content=content, additional_kwargs=additional_kwargs) def _build_tool_messages( @@ -151,7 +156,7 @@ def _build_tool_messages( formatted_content = json.loads(content) if formatted_content.get("event") == ChatEventType.TOOL_CALL_REQUEST: - tool_call = { + tool_call: ToolCall = { "name": formatted_content["name"], "args": formatted_content["arguments"], "id": formatted_content["id"], @@ -159,12 +164,12 @@ def _build_tool_messages( if num_tool_calls == 0: # First tool call - create new AIMessage - message = AIMessage(content=[], tool_calls=[tool_call]) # type: ignore[arg-type] + message = AIMessage(content=[], tool_calls=[tool_call]) return message, num_tool_calls + 1 else: # Subsequent tool call - append to existing AIMessage if history and isinstance(history[-1], AIMessage) and hasattr(history[-1], "tool_calls"): - history[-1].tool_calls.append(tool_call) # type: ignore[arg-type] + history[-1].tool_calls.append(tool_call) return None, num_tool_calls + 1 elif formatted_content.get("event") == ChatEventType.TOOL_CALL_RESPONSE: diff --git a/service/app/core/chat/langchain.py b/service/app/core/chat/langchain.py index 3b2d93b1..c9eba537 100644 --- a/service/app/core/chat/langchain.py +++ b/service/app/core/chat/langchain.py @@ -18,10 +18,12 @@ from app.core.prompts import build_system_prompt from app.core.providers import get_user_provider_manager from app.models.topic import Topic as 
TopicModel -from app.schemas.chat_event_types import StreamingEvent +from app.schemas.chat_event_payloads import StreamingEvent +from .agent_event_handler import AgentEventContext from .history import load_conversation_history from .stream_handlers import ( + AgentEventStreamHandler, CitationExtractor, GeneratedFileHandler, StreamContext, @@ -94,7 +96,7 @@ async def get_ai_response_stream_langchain_legacy( try: # Create LangChain agent - langchain_agent = await _create_langchain_agent( + langchain_agent, event_ctx = await _create_langchain_agent( db=db, agent=agent, topic=topic, @@ -109,6 +111,7 @@ async def get_ai_response_stream_langchain_legacy( stream_id=f"stream_{int(asyncio.get_event_loop().time() * 1000)}", db=db, user_id=user_id, + event_ctx=event_ctx, ) # Load conversation history @@ -179,9 +182,9 @@ async def _create_langchain_agent( provider_id: str | None, model_name: str | None, system_prompt: str, -) -> CompiledStateGraph: +) -> tuple[CompiledStateGraph[Any, None, Any, Any], AgentEventContext]: """Create and configure the LangChain agent using the agent factory.""" - return await create_chat_agent( + graph, event_ctx = await create_chat_agent( db=db, agent_config=agent, topic=topic, @@ -190,18 +193,30 @@ async def _create_langchain_agent( model_name=model_name, system_prompt=system_prompt, ) + return graph, event_ctx async def _process_agent_stream( - agent: CompiledStateGraph, + agent: CompiledStateGraph[Any, None, Any, Any], history_messages: list[Any], ctx: StreamContext, ) -> AsyncGenerator[StreamingEvent, None]: """Process the agent stream and yield events.""" logger.info("Starting agent.astream with stream_mode=['updates','messages']") logger.info(f"Length of history: {len(history_messages)}") + logger.info(f"[AgentEvent] event_ctx present: {ctx.event_ctx is not None}") + if ctx.event_ctx: + logger.info(f"[AgentEvent] agent_name={ctx.event_ctx.agent_name}, agent_type={ctx.event_ctx.agent_type}") + + # Emit agent_start event + if ctx.event_ctx: + agent_start_event = AgentEventStreamHandler.create_agent_start_event(ctx) + if agent_start_event: + logger.info(f"[AgentEvent] Emitting agent_start for {ctx.event_ctx.agent_name}") + yield agent_start_event chunk_count = 0 + async for chunk in agent.astream({"messages": history_messages}, stream_mode=["updates", "messages"]): chunk_count += 1 try: @@ -215,6 +230,9 @@ async def _process_agent_stream( if isinstance(data, dict): logger.info(f"[Updates] step_names={list(data.keys())}") # Log content of each step for debugging + # NOTE: Node events are primarily emitted from messages mode for accurate timing + # Updates mode handles non-streaming nodes (structured output nodes) + # that were already handled in _handle_updates_mode for step_name, step_data in data.items(): messages = step_data.get("messages", []) if isinstance(step_data, dict) else [] logger.info(f"[Updates/{step_name}] messages_count={len(messages)}") @@ -241,10 +259,24 @@ async def _process_agent_stream( logger.info(f"Stream finished after {chunk_count} chunks, is_streaming={ctx.is_streaming}") + # Emit node_end for the last node + if ctx.current_node: + node_end_event = AgentEventStreamHandler.create_node_end_event(ctx, ctx.current_node) + if node_end_event: + logger.info(f"[AgentEvent] Emitting final node_end for {ctx.current_node}") + yield node_end_event + # Finalize streaming async for event in _finalize_streaming(ctx): yield event + # Emit agent_end event + if ctx.event_ctx and ctx.agent_started: + agent_end_event = 
AgentEventStreamHandler.create_agent_end_event(ctx, "completed") + if agent_end_event: + logger.info(f"[AgentEvent] Emitting agent_end for {ctx.event_ctx.agent_name}") + yield agent_end_event + async def _handle_updates_mode(data: Any, ctx: StreamContext) -> AsyncGenerator[StreamingEvent, None]: """Handle 'updates' mode events (tool calls, model responses).""" @@ -254,12 +286,40 @@ async def _handle_updates_mode(data: Any, ctx: StreamContext) -> AsyncGenerator[ for step_name, step_data in data.items(): logger.debug("Update step: %s", step_name) + # Skip if step_data is None or not a dict + if not step_data or not isinstance(step_data, dict): + continue + messages = step_data.get("messages", []) if not messages: continue last_message = messages[-1] + # Extract agent_state from AIMessage additional_kwargs (for persistence) + if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs: + msg_agent_state = last_message.additional_kwargs.get("agent_state") + if msg_agent_state: + logger.debug("Extracted agent_state from step '%s': %s", step_name, list(msg_agent_state.keys())) + # Accumulate node outputs across all nodes + if ctx.agent_state is None: + ctx.agent_state = {"node_outputs": {}} + # Merge node outputs + if "node_outputs" in msg_agent_state: + ctx.agent_state.setdefault("node_outputs", {}).update(msg_agent_state["node_outputs"]) + # Track current node + if "current_node" in msg_agent_state: + ctx.agent_state["current_node"] = msg_agent_state["current_node"] + + # Also extract node_metadata for logging + node_metadata = last_message.additional_kwargs.get("node_metadata") + if node_metadata: + logger.info( + "Node completed: %s (is_intermediate=%s)", + node_metadata.get("node_name"), + node_metadata.get("is_intermediate"), + ) + # Tool call request if hasattr(last_message, "tool_calls") and last_message.tool_calls: logger.debug("Detected tool_calls in step '%s'", step_name) @@ -279,6 +339,33 @@ async def _handle_updates_mode(data: Any, ctx: StreamContext) -> AsyncGenerator[ logger.debug("Tool finished in step '%s' id=%s", step_name, tool_call_id) yield ToolEventHandler.create_tool_response_event(tool_call_id, formatted_result) + # Structured output nodes (clarify_with_user, write_research_brief, etc.) 
+ # These nodes use with_structured_output and don't stream normally + # They return clean content in messages, so we emit it as if streamed + elif hasattr(last_message, "content") and step_name in { + "clarify_with_user", + "write_research_brief", + }: + content = last_message.content + if isinstance(content, str) and content: + logger.debug("Structured output from '%s': %s", step_name, content[:100]) + + # Emit node_start if not already current node + if step_name != ctx.current_node: + if ctx.current_node: + node_end_event = AgentEventStreamHandler.create_node_end_event(ctx, ctx.current_node) + if node_end_event: + yield node_end_event + node_start_event = AgentEventStreamHandler.create_node_start_event(ctx, step_name) + if node_start_event: + yield node_start_event + + # Emit the content as if it was streamed (single chunk for the whole message) + if not ctx.is_streaming: + ctx.is_streaming = True + yield StreamingEventHandler.create_streaming_start(ctx.stream_id) + yield StreamingEventHandler.create_streaming_chunk(ctx.stream_id, content) + # Final model response (from 'model' or 'agent' step) elif ( hasattr(last_message, "content") @@ -325,13 +412,35 @@ async def _handle_messages_mode(data: Any, ctx: StreamContext) -> AsyncGenerator ctx.token_count, ) - # Only stream from LLM-related nodes ('model' or 'agent') - # Note: 'model' is used by older LangGraph, 'agent' is used by create_react_agent + # Only skip streaming from tool execution nodes and structured output nodes + # - 'tools' node: where tool calls are executed, not LLM responses + # - 'clarify_with_user': uses structured output (JSON), we only want the final extracted message + # - 'write_research_brief': uses structured output, handled in updates mode + # All other LLM nodes should stream their output normally + SKIP_STREAMING_NODES = {"tools", "clarify_with_user", "write_research_brief"} + + node: str | None = None if isinstance(metadata, dict): node = metadata.get("langgraph_node") or metadata.get("node") - if node and node not in ("model", "agent"): + if node in SKIP_STREAMING_NODES: return + # Emit node events based on streaming metadata (more accurate timing than updates mode) + # This ensures node_start is emitted BEFORE streaming chunks for that node + if node and node != ctx.current_node: + # Emit node_end for previous node + if ctx.current_node: + node_end_event = AgentEventStreamHandler.create_node_end_event(ctx, ctx.current_node) + if node_end_event: + logger.info(f"[AgentEvent/Messages] Emitting node_end for {ctx.current_node}") + yield node_end_event + + # Emit node_start for new node + node_start_event = AgentEventStreamHandler.create_node_start_event(ctx, node) + if node_start_event: + logger.info(f"[AgentEvent/Messages] Emitting node_start for {node}") + yield node_start_event + # Check for thinking content first (from reasoning models like Claude, DeepSeek R1, Gemini 3) thinking_content = ThinkingEventHandler.extract_thinking_content(message_chunk) @@ -376,11 +485,12 @@ async def _finalize_streaming(ctx: StreamContext) -> AsyncGenerator[StreamingEve if ctx.is_streaming: logger.debug( - "Emitting streaming_end for stream_id=%s (total tokens: %d)", + "Emitting streaming_end for stream_id=%s (total tokens: %d, has_agent_state=%s)", ctx.stream_id, ctx.token_count, + ctx.agent_state is not None, ) - yield StreamingEventHandler.create_streaming_end(ctx.stream_id) + yield StreamingEventHandler.create_streaming_end(ctx.stream_id, ctx.agent_state) # Emit token usage if ctx.total_tokens > 0 or ctx.total_input_tokens > 
0 or ctx.total_output_tokens > 0: diff --git a/service/app/core/chat/stream_handlers.py b/service/app/core/chat/stream_handlers.py index 7a518e20..f410825a 100644 --- a/service/app/core/chat/stream_handlers.py +++ b/service/app/core/chat/stream_handlers.py @@ -10,13 +10,14 @@ import asyncio import base64 import logging +import time from dataclasses import dataclass, field from io import BytesIO -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, Mapping, Sequence, cast from langchain_core.messages import AIMessage -from app.schemas.chat_event_types import ( +from app.schemas.chat_event_payloads import ( CitationData, GeneratedFileInfo, GeneratedFilesData, @@ -32,11 +33,12 @@ ToolCallRequestData, ToolCallResponseData, ) -from app.schemas.chat_events import ChatEventType, ProcessingStatus, ToolCallStatus +from app.schemas.chat_event_types import ChatEventType, ProcessingStatus, ToolCallStatus if TYPE_CHECKING: from sqlmodel.ext.asyncio.session import AsyncSession + from app.core.chat.agent_event_handler import AgentEventContext from app.models.file import File logger = logging.getLogger(__name__) @@ -52,6 +54,7 @@ class StreamContext: stream_id: str db: "AsyncSession" user_id: str + event_ctx: "AgentEventContext | None" = None is_streaming: bool = False assistant_buffer: list[str] = field(default_factory=list) token_count: int = 0 @@ -61,6 +64,13 @@ class StreamContext: # Thinking/reasoning content state is_thinking: bool = False thinking_buffer: list[str] = field(default_factory=list) + # Agent execution state + agent_started: bool = False + current_node: str | None = None + agent_start_time: float = 0.0 + node_start_time: float = 0.0 + # Agent state metadata (for persistence) + agent_state: dict[str, Any] | None = None class ToolEventHandler: @@ -126,12 +136,14 @@ def create_streaming_chunk(stream_id: str, content: str) -> StreamingEvent: return {"type": ChatEventType.STREAMING_CHUNK, "data": data} @staticmethod - def create_streaming_end(stream_id: str) -> StreamingEvent: + def create_streaming_end(stream_id: str, agent_state: dict[str, Any] | None = None) -> StreamingEvent: """Create streaming end event.""" data: StreamingEndData = { "id": stream_id, "created_at": asyncio.get_event_loop().time(), } + if agent_state: + data["agent_state"] = agent_state return {"type": ChatEventType.STREAMING_END, "data": data} @staticmethod @@ -224,9 +236,9 @@ def extract_thinking_content(message_chunk: Any) -> str | None: # Check response_metadata for thinking content if hasattr(message_chunk, "response_metadata"): - metadata = message_chunk.response_metadata - if isinstance(metadata, dict): - # Gemini 3 uses "reasoning" key + metadata_raw = message_chunk.response_metadata + if isinstance(metadata_raw, dict): + metadata: dict[str, Any] = metadata_raw thinking = ( metadata.get("thinking") or metadata.get("reasoning_content") @@ -234,7 +246,8 @@ def extract_thinking_content(message_chunk: Any) -> str | None: or metadata.get("thoughts") ) if thinking: - logger.debug("Found thinking in response_metadata: %s", list(metadata.keys())) + keys_str = ", ".join(map(str, metadata.keys())) + logger.debug("Found thinking in response_metadata keys: %s", keys_str) return thinking return None @@ -244,7 +257,7 @@ class CitationExtractor: """Extract citations from LLM response metadata.""" @staticmethod - def extract_citations(response_metadata: dict[str, Any]) -> list[CitationData]: + def extract_citations(response_metadata: Mapping[str, Any] | None) -> 
list[CitationData]: """ Extract citations from response metadata. @@ -258,7 +271,7 @@ def extract_citations(response_metadata: dict[str, Any]) -> list[CitationData]: """ citations: list[CitationData] = [] - if not isinstance(response_metadata, dict): + if not response_metadata: return citations # 1. Handle Google Grounding Metadata @@ -275,7 +288,7 @@ def extract_citations(response_metadata: dict[str, Any]) -> list[CitationData]: return CitationExtractor._deduplicate_citations(citations) @staticmethod - def _extract_google_grounding(grounding_metadata: dict[str, Any]) -> list[CitationData]: + def _extract_google_grounding(grounding_metadata: Mapping[str, Any]) -> list[CitationData]: """Extract citations from Google Grounding Metadata.""" citations: list[CitationData] = [] @@ -320,7 +333,7 @@ def _extract_google_grounding(grounding_metadata: dict[str, Any]) -> list[Citati return citations @staticmethod - def _extract_openai_annotations(annotations: list[Any]) -> list[CitationData]: + def _extract_openai_annotations(annotations: Sequence[Any]) -> list[CitationData]: """Extract citations from OpenAI annotations.""" citations: list[CitationData] = [] @@ -429,7 +442,7 @@ async def save_generated_image(image_data: str, user_id: str, db: "AsyncSession" @staticmethod async def process_generated_content( - content: list[Any], user_id: str, db: "AsyncSession" + content: Sequence[Any], user_id: str, db: "AsyncSession" ) -> tuple[list["File"], list[GeneratedFileInfo]]: """ Process multimodal content and save any generated images. @@ -570,15 +583,171 @@ async def process_model_response( # Handle multimodal content (e.g., generated images) if isinstance(content, list): - generated_files, files_data = await GeneratedFileHandler.process_generated_content( - content, ctx.user_id, ctx.db + typed_content = cast(Sequence[Any], content) + _generated_files, files_data = await GeneratedFileHandler.process_generated_content( + typed_content, ctx.user_id, ctx.db ) if files_data: yield GeneratedFileHandler.create_generated_files_event(files_data) # Extract and emit citations - if hasattr(message, "response_metadata"): - citations = CitationExtractor.extract_citations(message.response_metadata) + metadata = cast(Mapping[str, Any] | None, getattr(message, "response_metadata", None)) + if metadata: + citations = CitationExtractor.extract_citations(metadata) if citations: logger.info(f"Emitting {len(citations)} unique search citations") yield CitationExtractor.create_citations_event(citations) + + +class AgentEventStreamHandler: + """ + Handle agent execution events for streaming to frontend. + + Automatically emits agent_start, node_start/end, and agent_end events + based on LangGraph execution metadata. 
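+
+    Each factory returns None when no event context is attached to the
+    StreamContext, so callers can guard the yield (an illustrative sketch):
+
+        if event := AgentEventStreamHandler.create_node_start_event(ctx, "agent"):
+            yield event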
+ """ + + @staticmethod + def create_agent_start_event(ctx: StreamContext) -> StreamingEvent | None: + """Create agent_start event from context.""" + if not ctx.event_ctx: + return None + + ctx.agent_started = True + ctx.agent_start_time = time.time() + + return { + "type": ChatEventType.AGENT_START, + "data": { + "context": { + "agent_id": ctx.event_ctx.agent_id, + "agent_name": ctx.event_ctx.agent_name, + "agent_type": ctx.event_ctx.agent_type, + "execution_id": ctx.event_ctx.execution_id, + "depth": 0, + "execution_path": [ctx.event_ctx.agent_name], + "started_at": int(ctx.agent_start_time * 1000), + }, + }, + } + + @staticmethod + def create_agent_end_event(ctx: StreamContext, status: str = "completed") -> StreamingEvent | None: + """Create agent_end event from context.""" + if not ctx.event_ctx or not ctx.agent_started: + return None + + duration_ms = int((time.time() - ctx.agent_start_time) * 1000) + + return { + "type": ChatEventType.AGENT_END, + "data": { + "context": { + "agent_id": ctx.event_ctx.agent_id, + "agent_name": ctx.event_ctx.agent_name, + "agent_type": ctx.event_ctx.agent_type, + "execution_id": ctx.event_ctx.execution_id, + "depth": 0, + "execution_path": [ctx.event_ctx.agent_name], + "started_at": int(ctx.agent_start_time * 1000), + }, + "status": status, + "duration_ms": duration_ms, + }, + } + + @staticmethod + def create_node_start_event(ctx: StreamContext, node_name: str) -> StreamingEvent | None: + """Create node_start event.""" + if not ctx.event_ctx: + return None + + ctx.current_node = node_name + ctx.node_start_time = time.time() + + # Map common node names to more descriptive types + node_type = "llm" + if node_name == "tools": + node_type = "tool" + elif node_name in ("router", "route", "conditional"): + node_type = "router" + + return { + "type": ChatEventType.NODE_START, + "data": { + # Use raw node_name as node_id for frontend compatibility + # (frontend uses node_id for display name lookup) + "node_id": node_name, + "node_name": node_name, + "node_type": node_type, + "context": { + "agent_id": ctx.event_ctx.agent_id, + "agent_name": ctx.event_ctx.agent_name, + "agent_type": ctx.event_ctx.agent_type, + "execution_id": ctx.event_ctx.execution_id, + "depth": 0, + "execution_path": [ctx.event_ctx.agent_name], + "current_node": node_name, + "started_at": int(ctx.agent_start_time * 1000), + }, + }, + } + + @staticmethod + def create_node_end_event(ctx: StreamContext, node_name: str, status: str = "completed") -> StreamingEvent | None: + """Create node_end event.""" + if not ctx.event_ctx: + return None + + duration_ms = int((time.time() - ctx.node_start_time) * 1000) if ctx.node_start_time else 0 + + node_type = "llm" + if node_name == "tools": + node_type = "tool" + elif node_name in ("router", "route", "conditional"): + node_type = "router" + + return { + "type": ChatEventType.NODE_END, + "data": { + # Use raw node_name as node_id for frontend compatibility + "node_id": node_name, + "node_name": node_name, + "node_type": node_type, + "status": status, + "duration_ms": duration_ms, + "context": { + "agent_id": ctx.event_ctx.agent_id, + "agent_name": ctx.event_ctx.agent_name, + "agent_type": ctx.event_ctx.agent_type, + "execution_id": ctx.event_ctx.execution_id, + "depth": 0, + "execution_path": [ctx.event_ctx.agent_name], + "current_node": node_name, + "started_at": int(ctx.agent_start_time * 1000), + }, + }, + } + + @staticmethod + def create_progress_event(ctx: StreamContext, percent: int, message: str) -> StreamingEvent | None: + """Create progress_update 
event.""" + if not ctx.event_ctx: + return None + + return { + "type": ChatEventType.PROGRESS_UPDATE, + "data": { + "progress_percent": percent, + "message": message, + "context": { + "agent_id": ctx.event_ctx.agent_id, + "agent_name": ctx.event_ctx.agent_name, + "agent_type": ctx.event_ctx.agent_type, + "execution_id": ctx.event_ctx.execution_id, + "depth": 0, + "execution_path": [ctx.event_ctx.agent_name], + "started_at": int(ctx.agent_start_time * 1000), + }, + }, + } diff --git a/service/app/mcp/graph_tools.py b/service/app/mcp/graph_tools.py deleted file mode 100644 index 7ec9a8c5..00000000 --- a/service/app/mcp/graph_tools.py +++ /dev/null @@ -1,1162 +0,0 @@ -# """ -# MCP Server for Graph Agent Tools - Simple AI Agent Creation - -# This module provides tools for creating graph-based AI agents. Use `create_agent_with_graph` -# for the simplest approach, then `inspect_agent` to verify and `run_agent` to test. - -# ## 🚀 RECOMMENDED: CREATE COMPLETE AGENT IN ONE CALL - -# Use `create_agent_with_graph()` - it's the easiest way to create working agents: - -# ```python -# create_agent_with_graph( -# name="Q&A Assistant", -# description="Answers user questions", -# state_schema={ -# "type": "object", -# "properties": { -# "messages": {"type": "array"}, -# "current_step": {"type": "string"}, -# "user_input": {"type": "string"}, -# "final_output": {"type": "string"} -# }, -# "required": ["messages", "current_step"] -# }, -# nodes=[ -# {"name": "start", "node_type": "start", "config": {}}, -# { -# "name": "assistant", -# "node_type": "llm", -# "config": { -# "model": "gpt-5", -# "provider_name": "system", -# "system_prompt": "You are a helpful assistant. Answer questions clearly." -# } -# }, -# {"name": "end", "node_type": "end", "config": {}} -# ], -# edges=[ -# {"from_node": "start", "to_node": "assistant"}, -# {"from_node": "assistant", "to_node": "end"} -# ] -# ) -# ``` - -# Then ALWAYS use `inspect_agent(agent_id)` to verify your agent before running it! - -# ## 🚨 CRITICAL: ALWAYS USE SYSTEM PROVIDER - -# For ALL LLM nodes: `"provider_name": "system"` is REQUIRED! - -# ## 📈 WORKFLOW: CREATE → INSPECT → RUN - -# 1. **Create**: Use `create_agent_with_graph()` (copy the template above) -# 2. **Inspect**: Use `inspect_agent(agent_id)` to verify structure -# 3. 
**Run**: Use `run_agent(agent_id, input_state)` to test
-
-# ## 📋 QUICK NODE REFERENCE
-
-# - **"start"**: Entry point, config: `{}`
-# - **"llm"**: AI processing, config: `{"model": "gpt-5", "provider_name": "system", "system_prompt": "..."}`
-# - **"tool"**: Function calls, config: `{"tool_name": "function_name"}`
-# - **"router"**: Branching logic, config: `{"conditions": [...], "default_target": "node_name"}`
-# - **"end"**: Exit point, config: `{}`
-
-# ## 🔧 ADVANCED: Individual Functions
-
-# If you need to build agents step-by-step instead of using `create_agent_with_graph`:
-# - `create_agent()`: Create empty agent
-# - `add_node()`: Add individual nodes
-# - `add_edge()`: Connect nodes
-# - `define_state()`: Customize state schema
-
-# ## 🛠️ ESSENTIAL TOOLS
-
-# - `inspect_agent(agent_id)`: **ALWAYS use this to verify your agent!**
-# - `validate_agent_structure(agent_id)`: Check for problems
-# - `list_agents()`: See all your agents
-# - `run_agent(agent_id, input_state)`: Execute your agent
-# """
-
-# import json
-# import logging
-# from typing import Any
-# from uuid import UUID
-
-# from fastmcp import FastMCP
-# from fastmcp.server.auth import JWTVerifier, TokenVerifier
-# from fastmcp.server.dependencies import AccessToken, get_access_token
-
-# from app.core.chat.langgraph import execute_graph_agent_sync
-# from app.infra.database import AsyncSessionLocal
-# from app.middleware.auth import AuthProvider, UserInfo
-# from app.middleware.auth import AuthProvider as InternalAuthProvider
-# from app.middleware.auth.token_verifier.bohr_app_token_verifier import BohrAppTokenVerifier
-# from app.models.graph import (
-#     GraphAgentCreate,
-#     GraphAgentUpdate,
-#     GraphEdgeCreate,
-#     GraphNodeCreate,
-# )
-# from app.repos.graph import GraphRepository
-
-# logger = logging.getLogger(__name__)
-
-# # MCP Server instance
-# mcp = FastMCP("graph-tools")
-
-# # Create the authentication provider - declared with the TokenVerifier type but assigned to the variable name auth
-# # This variable is recognized as an AuthProvider by MCP's auto-discovery mechanism (because TokenVerifier inherits from AuthProvider)
-# auth: TokenVerifier
-
-# match InternalAuthProvider.get_provider_name():
-#     case "bohrium":
-#         auth = JWTVerifier(
-#             public_key=InternalAuthProvider.public_key,
-#         )
-#     case "casdoor":
-#         auth = JWTVerifier(
-#             jwks_uri=InternalAuthProvider.jwks_uri,
-#         )
-#     case "bohr_app":
-#         auth = BohrAppTokenVerifier(
-#             api_url=InternalAuthProvider.issuer,
-#             x_app_key="xyzen-uuid1760783737",
-#         )
-#     case _:
-#         raise ValueError(f"Unsupported authentication provider: {InternalAuthProvider.get_provider_name()}")
-
-
-# def error_response(message: str) -> str:
-#     """Helper function to return consistent error responses"""
-#     return json.dumps(
-#         {
-#             "status": "error",
-#             "message": message,
-#         },
-#         indent=2,
-#     )
-
-
-# def get_current_user() -> UserInfo:
-#     """
-#     Dependency function to get the current user from the access token.
-#     """
-#     access_token: AccessToken | None = get_access_token()
-#     if not access_token:
-#         raise ValueError("Access token is required for this operation.")
-
-#     user_info = AuthProvider.parse_user_info(access_token.claims)
-#     if not user_info or not user_info.id:
-#         raise ValueError(f"Hello, unknown! 
Your scopes are: {', '.join(access_token.scopes)}") -# return user_info - - -# async def get_node_id_by_name(repo: GraphRepository, agent_id: UUID, node_name: str) -> UUID: -# """Helper to get node ID by name within an agent""" -# nodes = await repo.get_nodes_by_agent(agent_id) -# for node in nodes: -# if node.name == node_name: -# return node.id -# raise ValueError(f"Node '{node_name}' not found in agent {agent_id}") - - -# def get_node_config_template(node_type: str) -> dict[str, Any]: -# """Get a template configuration for a specific node type""" -# templates = { -# "llm": { -# "model": "gpt-5", -# "provider_name": "system", -# "system_prompt": "You are a helpful assistant. Process the input and provide a response.", -# }, -# "tool": {"tool_name": "example_tool", "parameters": {}, "timeout_seconds": 30}, -# "router": { -# "conditions": [{"field": "intent", "operator": "equals", "value": "search", "target": "search_node"}], -# "default_target": "default_node", -# }, -# "subagent": {"agent_id": "sub-agent-uuid", "input_mapping": {}, "output_mapping": {}}, -# "start": {}, -# "end": {"output_format": "json"}, -# } -# return templates.get(node_type, {}) - - -# def success_response(message: str, data: dict[str, Any] | None = None) -> str: -# """Helper function to return consistent success responses""" -# response = { -# "status": "success", -# "message": message, -# } -# if data: -# response.update(data) -# return json.dumps(response, indent=2) - - -# @mcp.tool -# async def create_agent( -# name: str, -# description: str, -# ) -> str: -# """ -# ⚠️ ADVANCED: Create empty agent (requires more steps) - -# Creates an empty agent that you must build manually with add_node() and add_edge(). -# Most users should use create_agent_with_graph() instead for simpler workflow. - -# Args: -# name: Agent name -# description: What the agent does - -# Returns: -# JSON with agent_id for use in add_node() and add_edge() calls - -# 💡 RECOMMENDED: Use create_agent_with_graph() instead for easier agent creation! -# """ -# user_info = get_current_user() - -# try: -# if not name or not description: -# return error_response("Missing required fields: name, description") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Create agent with basic state schema -# agent_data = GraphAgentCreate( -# name=name, -# description=description, -# state_schema={ -# "type": "object", -# "properties": { -# "messages": {"type": "array"}, -# "current_step": {"type": "string"}, -# "user_input": {"type": "string"}, -# "final_output": {"type": "string"}, -# "execution_context": {"type": "object"}, -# }, -# }, -# ) - -# agent = await repo.create_graph_agent(agent_data, user_info.id) -# await session.commit() - -# logger.info(f"Created graph agent: {agent.id}") -# return json.dumps( -# { -# "status": "success", -# "message": f"Graph agent '{name}' created successfully", -# "agent_id": str(agent.id), -# "name": name, -# "description": description, -# }, -# indent=2, -# ) - -# except Exception as e: -# logger.error(f"Failed to create agent: {e}") -# return error_response(f"Error creating agent: {str(e)}") - - -# @mcp.tool -# async def define_state(agent_id: str, state_schema: dict[str, Any]) -> str: -# """ -# ⚠️ ADVANCED: Customize state schema (rarely needed) - -# Updates the data structure that flows between nodes. Most users should skip this -# since create_agent_with_graph() includes a good default schema. 
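-
-# Example (illustrative agent_id; the schema mirrors the default created by create_agent()):
-
-# define_state(
-#     agent_id="your-agent-id",
-#     state_schema={
-#         "type": "object",
-#         "properties": {
-#             "messages": {"type": "array"},
-#             "current_step": {"type": "string"},
-#             "final_output": {"type": "string"},
-#         },
-#     },
-# )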
- -# Args: -# agent_id: Agent ID from create_agent() -# state_schema: JSON schema object - -# 💡 RECOMMENDED: Use create_agent_with_graph() with default schema instead! -# """ -# user_info = get_current_user() - -# try: -# if not agent_id or not state_schema: -# return error_response("Missing required fields: agent_id, state_schema") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Check agent exists and user has permission -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to modify this agent") - -# update_data = GraphAgentUpdate(state_schema=state_schema) -# updated_agent = await repo.update_graph_agent(UUID(agent_id), update_data) - -# if not updated_agent: -# return error_response(f"Failed to update agent {agent_id}") - -# await session.commit() - -# logger.info(f"Updated state schema for agent: {agent_id}") -# return json.dumps( -# { -# "status": "success", -# "message": f"Successfully updated state schema for agent {agent_id}", -# "agent_id": agent_id, -# }, -# indent=2, -# ) - -# except Exception as e: -# logger.error(f"Failed to define state: {e}") -# return error_response(f"Error defining state: {str(e)}") - - -# @mcp.tool -# async def add_node( -# agent_id: str, -# name: str, -# node_type: str, -# config: dict[str, Any], -# position_x: float | None = None, -# position_y: float | None = None, -# ) -> str: -# """ -# ⚠️ ADVANCED: Add individual nodes (manual method) - -# Adds a single node to an agent created with create_agent(). -# Most users should use create_agent_with_graph() instead for simpler workflow. - -# Args: -# agent_id: Agent ID from create_agent() -# name: Unique node name -# node_type: "start", "llm", "tool", "router", "end" -# config: Node configuration (LLM nodes need "provider_name": "system") - -# 💡 RECOMMENDED: Use create_agent_with_graph() instead for easier agent creation! -# """ -# user_info = get_current_user() - -# try: -# if not agent_id or not name or not node_type: -# return error_response("Missing required fields: agent_id, name, node_type") - -# # Validate node type -# valid_types = ["llm", "tool", "router", "subagent", "start", "end"] -# if node_type not in valid_types: -# return error_response(f"Invalid node type '{node_type}'. 
Valid types: {valid_types}") - -# # Validate provider_name for LLM nodes -# if node_type == "llm" and config.get("provider_name"): -# from app.core.providers import get_user_provider_manager - -# async with AsyncSessionLocal() as temp_session: -# try: -# user_provider_manager = await get_user_provider_manager(user_info.id, temp_session) -# provider = user_provider_manager.get_provider_config(config["provider_name"]) -# if not provider: -# return error_response( -# f"Provider '{config['provider_name']}' not found or not available to user" -# ) -# except Exception as e: -# logger.warning(f"Could not validate provider '{config.get('provider_name')}': {e}") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Check agent exists and user has permission -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to modify this agent") - -# node_data = GraphNodeCreate( -# name=name, -# node_type=node_type, -# config=config, -# graph_agent_id=UUID(agent_id), -# position_x=position_x, -# position_y=position_y, -# ) - -# node = await repo.create_node(node_data) -# await session.commit() - -# logger.info(f"Added node '{name}' to agent {agent_id}") -# return json.dumps( -# { -# "status": "success", -# "message": f"Successfully added {node_type} node '{name}'", -# "node_id": str(node.id), -# "agent_id": agent_id, -# "name": name, -# "node_type": node_type, -# }, -# indent=2, -# ) - -# except Exception as e: -# logger.error(f"Failed to add node: {e}") -# return error_response(f"Error adding node: {str(e)}") - - -# @mcp.tool -# async def add_edge( -# agent_id: str, -# from_node: str, -# to_node: str, -# condition: dict[str, Any] | None = None, -# label: str | None = None, -# ) -> str: -# """ -# ⚠️ ADVANCED: Connect nodes manually - -# Connects nodes created with add_node(). Node names must match exactly. -# Most users should use create_agent_with_graph() instead for simpler workflow. - -# Args: -# agent_id: Agent ID from create_agent() -# from_node: Source node name (must exist) -# to_node: Target node name (must exist) - -# 💡 RECOMMENDED: Use create_agent_with_graph() instead for easier agent creation! 
-# """ -# user_info = get_current_user() - -# try: -# if not agent_id or not from_node or not to_node: -# return error_response("Missing required fields: agent_id, from_node, to_node") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Check agent exists and user has permission -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to modify this agent") - -# agent_uuid = UUID(agent_id) - -# # Get node IDs by names -# from_node_id = await get_node_id_by_name(repo, agent_uuid, from_node) -# to_node_id = await get_node_id_by_name(repo, agent_uuid, to_node) - -# edge_data = GraphEdgeCreate( -# from_node_id=from_node_id, -# to_node_id=to_node_id, -# condition=condition, -# graph_agent_id=agent_uuid, -# label=label, -# ) - -# edge = await repo.create_edge(edge_data) -# await session.commit() - -# logger.info(f"Added edge from '{from_node}' to '{to_node}' in agent {agent_id}") -# return json.dumps( -# { -# "status": "success", -# "message": f"Successfully added edge from '{from_node}' to '{to_node}'", -# "edge_id": str(edge.id), -# "agent_id": agent_id, -# "from_node": from_node, -# "to_node": to_node, -# }, -# indent=2, -# ) - -# except Exception as e: -# logger.error(f"Failed to add edge: {e}") -# return error_response(f"Error adding edge: {str(e)}") - - -# @mcp.tool -# async def run_agent(agent_id: str, input_state: dict[str, Any]) -> str: -# """ -# 🚀 ESSENTIAL: Execute your agent - -# Runs your agent with the provided input. Use this to test and interact with your agent. - -# Args: -# agent_id: Agent ID from create_agent_with_graph() -# input_state: Input data - MUST include these fields: -# {"user_input": "question", "messages": [], "current_step": "start"} - -# Returns: -# JSON with execution results and the agent's response - -# ✅ EXAMPLE: -# run_agent( -# agent_id="your-agent-id", -# input_state={ -# "user_input": "Hello, how are you?", -# "messages": [], -# "current_step": "start" -# } -# ) - -# 💡 TIP: Use inspect_agent() first to verify your agent is properly structured! 
-# """ -# user_info = get_current_user() - -# try: -# if not agent_id or not input_state: -# return error_response("Missing required fields: agent_id, input_state") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Check agent exists and user has permission -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to execute this agent") - -# # Add user_id to input state for execution context -# enhanced_input_state = { -# **input_state, -# "execution_context": {**input_state.get("execution_context", {}), "user_id": user_info.id}, -# } - -# # Execute graph agent synchronously -# result = await execute_graph_agent_sync(session, UUID(agent_id), enhanced_input_state, user_info.id) - -# if result.success: -# return json.dumps( -# { -# "status": "success", -# "message": f"Agent executed successfully in {result.execution_time_ms}ms", -# "agent_id": agent_id, -# "final_state": result.final_state, -# "execution_time_ms": result.execution_time_ms, -# }, -# indent=2, -# ) -# else: -# return json.dumps( -# { -# "status": "error", -# "message": result.error_message or "Agent execution failed", -# "agent_id": agent_id, -# "execution_time_ms": result.execution_time_ms, -# }, -# indent=2, -# ) - -# except Exception as e: -# logger.error(f"Failed to run agent: {e}") -# return error_response(f"Error running agent: {str(e)}") - - -# @mcp.tool -# async def list_agents() -> str: -# """ -# List all graph agents for the current user. - -# Returns: -# JSON string containing list of agents -# """ -# user_info = get_current_user() - -# try: -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# agents = await repo.get_graph_agents_by_user(user_info.id) - -# if not agents: -# return json.dumps( -# { -# "status": "success", -# "message": "No graph agents found for current user", -# "agents": [], -# "count": 0, -# }, -# indent=2, -# ) - -# agent_list = [] -# for agent in agents: -# agent_info = { -# "id": str(agent.id), -# "name": agent.name, -# "description": agent.description, -# "is_active": agent.is_active, -# "created_at": agent.created_at.isoformat(), -# "updated_at": agent.updated_at.isoformat(), -# } -# agent_list.append(agent_info) - -# return json.dumps( -# { -# "status": "success", -# "agents": agent_list, -# "count": len(agents), -# }, -# indent=2, -# ) - -# except Exception as e: -# logger.error(f"Failed to list agents: {e}") -# return error_response(f"Error listing agents: {str(e)}") - - -# @mcp.tool -# async def create_agent_with_graph( -# name: str, -# description: str, -# state_schema: dict[str, Any], -# nodes: list[dict[str, Any]], -# edges: list[dict[str, Any]], -# ) -> str: -# """ -# 🚀 RECOMMENDED: Create a complete working agent in one call - -# This is the easiest and most reliable way to create graph agents. -# Copy the template below and modify it for your needs. 
- -# Args: -# name: Short name for your agent -# description: What the agent does -# state_schema: Use the template below (copy exactly) -# nodes: List of nodes (start, processing nodes, end) -# edges: List of connections between nodes - -# Returns: -# JSON with agent_id and creation confirmation - -# ✅ COPY THIS TEMPLATE (works every time): - -# create_agent_with_graph( -# name="Your Agent Name", -# description="What your agent does", -# state_schema={ -# "type": "object", -# "properties": { -# "messages": {"type": "array"}, -# "current_step": {"type": "string"}, -# "user_input": {"type": "string"}, -# "final_output": {"type": "string"} -# }, -# "required": ["messages", "current_step"] -# }, -# nodes=[ -# {"name": "start", "node_type": "start", "config": {}}, -# { -# "name": "assistant", -# "node_type": "llm", -# "config": { -# "model": "gpt-5", -# "provider_name": "system", -# "system_prompt": "Your custom prompt here" -# } -# }, -# {"name": "end", "node_type": "end", "config": {}} -# ], -# edges=[ -# {"from_node": "start", "to_node": "assistant"}, -# {"from_node": "assistant", "to_node": "end"} -# ] -# ) - -# 🔧 CUSTOMIZE: -# - Change the agent name and description -# - Modify the system_prompt for your use case -# - Add more nodes between "start" and "end" if needed -# - Connect new nodes with additional edges - -# ⚠️ CRITICAL: Always include "provider_name": "system" for LLM nodes! - -# 📋 NEXT STEPS: -# 1. Run this function to get agent_id -# 2. Use inspect_agent(agent_id) to verify -# 3. Use run_agent(agent_id, input_state) to test -# """ -# user_info = get_current_user() - -# try: -# if not name or not description or not state_schema or not nodes: -# return error_response("Missing required fields: name, description, state_schema, nodes") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Create agent first -# agent_data = GraphAgentCreate( -# name=name, -# description=description, -# state_schema=state_schema, -# ) -# agent = await repo.create_graph_agent(agent_data, user_info.id) - -# # Create nodes and build name-to-ID mapping -# node_id_map = {} -# for node_data in nodes: -# node_create = GraphNodeCreate( -# name=node_data["name"], -# node_type=node_data["node_type"], -# config=node_data.get("config", {}), -# graph_agent_id=agent.id, -# position_x=node_data.get("position_x"), -# position_y=node_data.get("position_y"), -# ) -# node = await repo.create_node(node_create) -# node_id_map[node.name] = node.id - -# # Create edges with resolved node IDs -# edges_created = 0 -# for edge_data in edges: -# from_name = edge_data["from_node"] -# to_name = edge_data["to_node"] - -# if from_name not in node_id_map or to_name not in node_id_map: -# logger.warning(f"Skipping edge from {from_name} to {to_name}: nodes not found") -# continue - -# edge_create = GraphEdgeCreate( -# from_node_id=node_id_map[from_name], -# to_node_id=node_id_map[to_name], -# condition=edge_data.get("condition"), -# graph_agent_id=agent.id, -# label=edge_data.get("label"), -# ) -# await repo.create_edge(edge_create) -# edges_created += 1 - -# await session.commit() - -# logger.info(f"Created complete graph agent: {agent.id}") -# return json.dumps( -# { -# "status": "success", -# "message": ( -# f"Successfully created graph agent '{name}' with {len(nodes)} nodes and {edges_created} edges" -# ), -# "agent_id": str(agent.id), -# "name": name, -# "description": description, -# "nodes_created": len(nodes), -# "edges_created": edges_created, -# }, -# indent=2, -# ) - -# except Exception 
as e: -# logger.error(f"Failed to create agent with graph: {e}") -# return error_response(f"Error creating agent with graph: {str(e)}") - - -# @mcp.tool -# async def inspect_agent(agent_id: str) -> str: -# """ -# 🔍 ESSENTIAL: View your agent structure (ALWAYS use this!) - -# This shows you exactly what your agent looks like - its nodes, connections, -# and whether it's properly structured. Use this after creating any agent! - -# Args: -# agent_id: The agent_id from create_agent_with_graph() result - -# Returns: -# Complete agent information including: -# - All nodes and their configurations -# - All connections (edges) -# - Validation status (errors/warnings) -# - Structure overview - -# 💡 USE THIS TO: -# - Verify your agent was created correctly -# - Debug connection issues -# - Check node configurations -# - Confirm the agent is ready to run - -# Example: -# inspect_agent(agent_id="your-agent-id-here") -# """ -# user_info = get_current_user() - -# try: -# if not agent_id: -# return error_response("Missing required field: agent_id") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Get agent details -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to inspect this agent") - -# # Get nodes and edges -# nodes = await repo.get_nodes_by_agent(UUID(agent_id)) -# edges = await repo.get_edges_by_agent(UUID(agent_id)) - -# # Build node details -# node_details = [] -# node_name_map = {} -# for node in nodes: -# node_info = { -# "id": str(node.id), -# "name": node.name, -# "type": node.node_type, -# "config": node.config, -# "position": {"x": node.position_x, "y": node.position_y}, -# } -# node_details.append(node_info) -# node_name_map[node.id] = node.name - -# # Build edge details -# edge_details = [] -# for edge in edges: -# edge_info = { -# "id": str(edge.id), -# "from_node": node_name_map.get(edge.from_node_id, "UNKNOWN"), -# "to_node": node_name_map.get(edge.to_node_id, "UNKNOWN"), -# "condition": edge.condition, -# "label": edge.label, -# } -# edge_details.append(edge_info) - -# # Graph statistics and validation -# node_types = {} -# for node in nodes: -# node_types[node.node_type] = node_types.get(node.node_type, 0) + 1 - -# has_start_node = any(node.node_type == "start" for node in nodes) -# has_end_node = any(node.node_type == "end" for node in nodes) - -# graph_validation = { -# "has_start_node": has_start_node, -# "has_end_node": has_end_node, -# "is_complete": has_start_node and has_end_node, -# "total_nodes": len(nodes), -# "total_edges": len(edges), -# "node_type_counts": node_types, -# } - -# return success_response( -# f"Agent '{agent.name}' inspection complete", -# { -# "agent": { -# "id": str(agent.id), -# "name": agent.name, -# "description": agent.description, -# "state_schema": agent.state_schema, -# "is_active": agent.is_active, -# "created_at": agent.created_at.isoformat(), -# "updated_at": agent.updated_at.isoformat(), -# }, -# "nodes": node_details, -# "edges": edge_details, -# "graph_validation": graph_validation, -# }, -# ) - -# except Exception as e: -# logger.error(f"Failed to inspect agent: {e}") -# return error_response(f"Error inspecting agent: {str(e)}") - - -# @mcp.tool -# async def get_node_template(node_type: str) -> str: -# """ -# Get a configuration template for a specific node type. 
- -# This tool provides ready-to-use configuration templates for each node type, -# which can be used as starting points when creating nodes. - -# Args: -# node_type: Type of node ('llm', 'tool', 'router', 'subagent', 'start', 'end') - -# Returns: -# JSON string with template configuration and usage guidance - -# Example Usage: -# get_node_template(node_type="llm") -# get_node_template(node_type="router") -# """ -# try: -# valid_types = ["llm", "tool", "router", "subagent", "start", "end"] -# if node_type not in valid_types: -# return error_response(f"Invalid node type '{node_type}'. Valid types: {valid_types}") - -# template = get_node_config_template(node_type) - -# return success_response( -# f"Configuration template for {node_type} node", -# { -# "node_type": node_type, -# "template": template, -# "usage_example": f"""add_node( -# agent_id="your-agent-id", -# name="your_node_name", -# node_type="{node_type}", -# config={json.dumps(template, indent=8)} -# )""", -# }, -# ) - -# except Exception as e: -# logger.error(f"Failed to get node template: {e}") -# return error_response(f"Error getting node template: {str(e)}") - - -# @mcp.tool -# async def list_user_providers() -> str: -# """ -# List available AI providers for the current user. - -# This tool shows all providers (both system and user-specific) that can be -# used in the provider_name field when creating LLM nodes. - -# Returns: -# JSON string with list of available providers including: -# - Provider names that can be used in LLM node configurations -# - Provider types (OpenAI, Anthropic, etc.) -# - Whether each provider is currently active -# - Provider availability status - -# Example Usage: -# list_user_providers() - -# Use the returned provider names in LLM node configurations: -# config = { -# "model": "gpt-5", -# "provider_name": "system", # Use a name from this list -# "system_prompt": "..." -# } -# """ -# user_info = get_current_user() - -# try: -# from app.core.providers import get_user_provider_manager - -# async with AsyncSessionLocal() as session: -# user_provider_manager = await get_user_provider_manager(user_info.id, session) - -# # Get list of providers -# providers_info = user_provider_manager.list_providers() - -# return success_response( -# f"Found {len(providers_info)} available providers for user", -# { -# "providers": providers_info, -# "count": len(providers_info), -# "usage_note": "Use the 'name' field values in LLM node 'provider_name' configuration", -# }, -# ) - -# except Exception as e: -# logger.error(f"Failed to list user providers: {e}") -# return error_response(f"Error listing user providers: {str(e)}") - - -# @mcp.tool -# async def delete_agent(agent_id: str) -> str: -# """ -# 🗑️ DELETE: Remove an agent permanently - -# Permanently deletes a graph agent and all its nodes and edges. This action cannot be undone! - -# Args: -# agent_id: The agent_id from create_agent_with_graph() or list_agents() - -# Returns: -# JSON confirmation of deletion - -# ⚠️ WARNING: This action is PERMANENT and cannot be undone! 
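-
-# On success, the confirmation JSON reports what was removed, e.g. (counts illustrative):
-# {"agent_id": "...", "agent_name": "...", "nodes_deleted": 3, "edges_deleted": 2}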
- -# 💡 SAFETY TIP: Use inspect_agent() first to verify you're deleting the correct agent - -# Example: -# delete_agent(agent_id="12345678-1234-1234-1234-123456789abc") -# """ -# user_info = get_current_user() - -# try: -# if not agent_id: -# return error_response("Missing required field: agent_id") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Check agent exists and user has permission -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to delete this agent") - -# # Get agent details for confirmation message -# agent_name = agent.name - -# # Get counts for confirmation -# nodes = await repo.get_nodes_by_agent(UUID(agent_id)) -# edges = await repo.get_edges_by_agent(UUID(agent_id)) -# node_count = len(nodes) -# edge_count = len(edges) - -# # Delete the agent (this should cascade to delete nodes and edges) -# success = await repo.delete_graph_agent(UUID(agent_id)) - -# if not success: -# return error_response(f"Failed to delete agent {agent_id}") - -# await session.commit() - -# logger.info( -# f"Deleted graph agent: {agent_id} ('{agent_name}') with {node_count} nodes and {edge_count} edges" -# ) -# return success_response( -# f"Successfully deleted agent '{agent_name}' and all its components", -# { -# "agent_id": agent_id, -# "agent_name": agent_name, -# "nodes_deleted": node_count, -# "edges_deleted": edge_count, -# "deletion_time": "permanent", -# }, -# ) - -# except Exception as e: -# logger.error(f"Failed to delete agent: {e}") -# return error_response(f"Error deleting agent: {str(e)}") - - -# @mcp.tool -# async def validate_agent_structure(agent_id: str) -> str: -# """ -# Validate the structure and configuration of a graph agent. - -# This tool performs comprehensive validation checks on an agent's structure, -# including node configurations, connectivity, and graph completeness. 
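-
-# The validation payload mirrors the structure built below, e.g. (values illustrative):
-# {"is_valid": false, "errors": ["Duplicate node name: 'assistant'"], "warnings": [], "recommendations": []}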
- -# Args: -# agent_id: UUID of the graph agent to validate - -# Returns: -# JSON string with validation results and recommendations - -# Example Usage: -# validate_agent_structure(agent_id="12345678-1234-1234-1234-123456789abc") -# """ -# user_info = get_current_user() - -# try: -# if not agent_id: -# return error_response("Missing required field: agent_id") - -# async with AsyncSessionLocal() as session: -# repo = GraphRepository(session) - -# # Get agent and check permissions -# agent = await repo.get_graph_agent_by_id(UUID(agent_id)) -# if not agent: -# return error_response(f"Agent {agent_id} not found") - -# if agent.user_id != user_info.id: -# return error_response("Permission denied: You don't have permission to validate this agent") - -# # Get nodes and edges -# nodes = await repo.get_nodes_by_agent(UUID(agent_id)) -# edges = await repo.get_edges_by_agent(UUID(agent_id)) - -# # Validation results -# validation_results = {"is_valid": True, "errors": [], "warnings": [], "recommendations": []} - -# # Node validation -# node_names = set() -# node_types = {} -# for node in nodes: -# # Check for duplicate names -# if node.name in node_names: -# validation_results["errors"].append(f"Duplicate node name: '{node.name}'") -# validation_results["is_valid"] = False -# node_names.add(node.name) - -# # Count node types -# node_types[node.node_type] = node_types.get(node.node_type, 0) + 1 - -# # Graph structure validation -# if "start" not in node_types: -# validation_results["errors"].append("Graph must have at least one 'start' node") -# validation_results["is_valid"] = False -# elif node_types["start"] > 1: -# validation_results["warnings"].append( -# f"Graph has {node_types['start']} start nodes - consider using only one" -# ) - -# if "end" not in node_types: -# validation_results["warnings"].append("Graph should have at least one 'end' node") - -# # Edge validation -# node_id_to_name = {node.id: node.name for node in nodes} -# connected_nodes = set() - -# for edge in edges: -# from_name = node_id_to_name.get(edge.from_node_id) -# to_name = node_id_to_name.get(edge.to_node_id) - -# if not from_name: -# validation_results["errors"].append( -# f"Edge references non-existent from_node ID: {edge.from_node_id}" -# ) -# validation_results["is_valid"] = False -# if not to_name: -# validation_results["errors"].append(f"Edge references non-existent to_node ID: {edge.to_node_id}") -# validation_results["is_valid"] = False - -# if from_name and to_name: -# connected_nodes.add(from_name) -# connected_nodes.add(to_name) - -# # Check for isolated nodes -# for node in nodes: -# if node.name not in connected_nodes and node.node_type not in ["start", "end"]: -# validation_results["warnings"].append(f"Node '{node.name}' is not connected to any other nodes") - -# # Recommendations -# if len(nodes) == 0: -# validation_results["recommendations"].append( -# "Start by adding a 'start' node to begin building your graph" -# ) -# elif len(edges) == 0 and len(nodes) > 1: -# validation_results["recommendations"].append( -# "Add edges to connect your nodes and define execution flow" -# ) - -# if "router" in node_types and node_types["router"] > 0: -# validation_results["recommendations"].append( -# "Ensure router nodes have proper conditions defined for all possible paths" -# ) - -# return success_response( -# f"Validation complete for agent '{agent.name}'", -# { -# "agent_id": agent_id, -# "validation": validation_results, -# "statistics": { -# "total_nodes": len(nodes), -# "total_edges": len(edges), -# 
"node_type_counts": node_types, -# "connected_nodes": len(connected_nodes), -# }, -# }, -# ) - -# except Exception as e: -# logger.error(f"Failed to validate agent structure: {e}") -# return error_response(f"Error validating agent structure: {str(e)}") - - -# __all__ = ["mcp"] diff --git a/service/app/repos/message.py b/service/app/repos/message.py index f8350e74..63e0ffb0 100644 --- a/service/app/repos/message.py +++ b/service/app/repos/message.py @@ -396,6 +396,7 @@ async def get_messages_with_files_and_citations( attachments=file_reads_with_urls, citations=citations, thinking_content=message.thinking_content, + agent_metadata=message.agent_metadata, ) messages_with_files_and_citations.append(message_with_files_and_citations) diff --git a/service/app/schemas/graph_config.py b/service/app/schemas/graph_config.py index 76f39add..752f478f 100644 --- a/service/app/schemas/graph_config.py +++ b/service/app/schemas/graph_config.py @@ -81,6 +81,26 @@ class GraphStateSchema(BaseModel): # --- Node Configuration Definitions --- +class StructuredOutputField(BaseModel): + """Definition of a field in structured output schema.""" + + type: str = Field(description="Field type: 'string', 'bool', 'int', 'float', 'list', 'dict'") + description: str = Field(default="", description="Field description for LLM guidance") + default: Any = Field(default=None, description="Default value if not provided") + required: bool = Field(default=True, description="Whether the field is required") + + +class StructuredOutputSchema(BaseModel): + """JSON-based schema for structured LLM output. + + This allows defining output structure directly in JSON config, + which is then converted to a Pydantic model at runtime. + """ + + fields: dict[str, StructuredOutputField] = Field(description="Field definitions for the structured output") + description: str = Field(default="", description="Description of what this output represents") + + class LLMNodeConfig(BaseModel): """Configuration for LLM reasoning nodes.""" @@ -94,6 +114,24 @@ class LLMNodeConfig(BaseModel): max_iterations: int = Field(default=10, ge=1, description="Maximum iterations for ReAct-style tool loops") stop_sequences: list[str] | None = Field(default=None, description="Stop sequences for generation") + # Structured output configuration + structured_output: StructuredOutputSchema | None = Field( + default=None, + description="Schema for structured JSON output. When set, LLM response is parsed into fields.", + ) + message_key: str | None = Field( + default=None, + description="Field from structured output to use as user-facing message (prevents raw JSON display).", + ) + message_key_condition: dict[str, str] | None = Field( + default=None, + description=( + "Conditional message field selection. Format: {'condition_field': 'bool_field', " + "'true_key': 'field_if_true', 'false_key': 'field_if_false'}. 
" + "Example: {'condition_field': 'need_clarification', 'true_key': 'question', 'false_key': 'verification'}" + ), + ) + class ToolNodeConfig(BaseModel): """Configuration for tool execution nodes.""" diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index 965a87e2..02094fb6 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -17,8 +17,8 @@ from app.models.citation import CitationCreate from app.models.message import Message, MessageCreate from app.repos import CitationRepository, FileRepository, MessageRepository, TopicRepository -from app.schemas.chat_event_types import CitationData -from app.schemas.chat_events import ChatEventType +from app.schemas.chat_event_payloads import CitationData +from app.schemas.chat_event_types import ChatEventType logger = logging.getLogger(__name__) @@ -213,6 +213,29 @@ async def _process_chat_message_async( elif stream_event["type"] == ChatEventType.STREAMING_END: full_content = stream_event["data"].get("content", full_content) + # Extract agent_state for persistence to message agent_metadata + agent_state_data = stream_event["data"].get("agent_state") + + # For graph-based agents, use final node output as message content + # instead of concatenated content from all nodes + if agent_state_data and "node_outputs" in agent_state_data: + node_outputs = agent_state_data["node_outputs"] + # Priority: final_report_generation > agent > model > fallback to streamed + final_content = ( + node_outputs.get("final_report_generation") + or node_outputs.get("agent") + or node_outputs.get("model") + ) + if final_content: + if isinstance(final_content, str): + full_content = final_content + elif isinstance(final_content, dict): + # Handle structured output - extract text content + full_content = final_content.get("content", str(final_content)) + + if agent_state_data and ai_message_obj: + ai_message_obj.agent_metadata = agent_state_data + db.add(ai_message_obj) await publisher.publish(json.dumps(stream_event)) elif stream_event["type"] == ChatEventType.TOKEN_USAGE: diff --git a/web/src/components/layouts/components/AgentExecutionBubble.tsx b/web/src/components/layouts/components/AgentExecutionBubble.tsx new file mode 100644 index 00000000..8184cec3 --- /dev/null +++ b/web/src/components/layouts/components/AgentExecutionBubble.tsx @@ -0,0 +1,295 @@ +import { AnimatePresence, motion } from "framer-motion"; +import { + ChevronDown, + ChevronRight, + FlaskConical, + CheckCircle2, + XCircle, + Loader2, + Clock, +} from "lucide-react"; +import { useState } from "react"; +import { useTranslation } from "react-i18next"; +import type { AgentExecutionState } from "@/types/agentEvents"; +import AgentPhaseItem from "./AgentPhaseItem"; +import AgentProgressBar from "./AgentProgressBar"; + +interface AgentExecutionBubbleProps { + execution: AgentExecutionState; + isExecuting: boolean; +} + +/** + * AgentExecutionBubble displays agent execution progress and history. + * + * Two states: + * 1. Active execution (isExecuting=true): Animated view showing progress and current phase + * 2. 
Collapsed (isExecuting=false): Expandable accordion to view execution timeline + */ +export default function AgentExecutionBubble({ + execution, + isExecuting, +}: AgentExecutionBubbleProps) { + const { t } = useTranslation(); + const [isExpanded, setIsExpanded] = useState(false); + + // Format duration in human-readable format + const formatDuration = (ms?: number): string => { + if (ms === undefined) return ""; + if (ms < 1000) return `${ms}ms`; + const seconds = ms / 1000; + if (seconds < 60) return `${seconds.toFixed(1)}s`; + const minutes = Math.floor(seconds / 60); + const remainingSeconds = (seconds % 60).toFixed(0); + return `${minutes}m ${remainingSeconds}s`; + }; + + // Get status icon + const getStatusIcon = () => { + switch (execution.status) { + case "running": + return ( + + + + ); + case "completed": + return ; + case "failed": + return ; + case "cancelled": + return ; + default: + return null; + } + }; + + // Get status badge color + const getStatusBadgeClass = () => { + switch (execution.status) { + case "running": + return "bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300"; + case "completed": + return "bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-300"; + case "failed": + return "bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300"; + case "cancelled": + return "bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-300"; + default: + return "bg-neutral-100 text-neutral-700 dark:bg-neutral-900/30 dark:text-neutral-300"; + } + }; + + return ( +
+ + {isExecuting ? ( + // Active execution state - animated progress view + + {/* Subtle shimmer effect */} + + + {/* Header with animated icon */} +
+ + + + + {execution.agentName} + + + {t(`app.chat.agent.status.${execution.status}`, { + defaultValue: + execution.status.charAt(0).toUpperCase() + + execution.status.slice(1), + })} + +
+ + {/* Progress bar */} +
+ +
+ + {/* Current phase/progress message */} +
+ + + + {execution.progressMessage || + execution.currentNode || + execution.currentPhase || + t("app.chat.agent.initializing", { + defaultValue: "Initializing...", + })} + + + + {/* Iteration indicator */} + {execution.iteration && ( +
+ {t("app.chat.agent.iteration", { + current: execution.iteration.current, + max: execution.iteration.max, + defaultValue: `Iteration ${execution.iteration.current}/${execution.iteration.max}`, + })} +
+ )} +
+
+ ) : ( + // Collapsed state - expandable accordion + + {/* Collapsible header */} + + + {/* Expanded content - Phase timeline */} + + {isExpanded && ( + +
+ {/* Phase list */} +
+ {execution.phases.map((phase) => ( + + ))} +
+ + {/* Subagent executions */} + {execution.subagents.length > 0 && ( +
+
+ {t("app.chat.agent.subagents", { + defaultValue: "Subagents", + })} +
+ {execution.subagents.map((subagent) => ( +
+ + └─ + + {subagent.status === "running" ? ( + + ) : subagent.status === "completed" ? ( + + ) : ( + + )} + + {subagent.name} + + {subagent.durationMs !== undefined && ( + + {formatDuration(subagent.durationMs)} + + )} +
+ ))} +
+ )} + + {/* Error display */} + {execution.error && ( +
+
+ {execution.error.type} +
+
+ {execution.error.message} +
+
+ )} +
+
+ )} +
+
+ )} +
+
+ ); +} diff --git a/web/src/components/layouts/components/AgentExecutionTimeline.tsx b/web/src/components/layouts/components/AgentExecutionTimeline.tsx new file mode 100644 index 00000000..e49676e0 --- /dev/null +++ b/web/src/components/layouts/components/AgentExecutionTimeline.tsx @@ -0,0 +1,324 @@ +import { motion } from "framer-motion"; +import { + FlaskConical, + CheckCircle2, + XCircle, + Loader2, + Clock, +} from "lucide-react"; +import { useTranslation } from "react-i18next"; +import type { AgentExecutionState } from "@/types/agentEvents"; +import AgentNodeItem from "./AgentNodeItem"; +import { useEffect, useRef } from "react"; + +interface AgentExecutionTimelineProps { + execution: AgentExecutionState; + isExecuting: boolean; +} + +/** + * AgentExecutionTimeline displays agent execution as a vertical timeline of phase cards. + * Replaces AgentExecutionBubble with a more detailed, real-time view of execution progress. + */ +export default function AgentExecutionTimeline({ + execution, + isExecuting, +}: AgentExecutionTimelineProps) { + const { t } = useTranslation(); + const activePhaseRef = useRef(null); + + // Auto-scroll to active phase during execution + useEffect(() => { + if (isExecuting && activePhaseRef.current) { + activePhaseRef.current.scrollIntoView({ + behavior: "smooth", + block: "nearest", + }); + } + }, [execution.currentPhase, isExecuting]); + + // Format duration in human-readable format + const formatDuration = (ms?: number): string => { + if (ms === undefined) return ""; + if (ms < 1000) return `${ms}ms`; + const seconds = ms / 1000; + if (seconds < 60) return `${seconds.toFixed(1)}s`; + const minutes = Math.floor(seconds / 60); + const remainingSeconds = (seconds % 60).toFixed(0); + return `${minutes}m ${remainingSeconds}s`; + }; + + // Get overall status icon + const getStatusIcon = () => { + switch (execution.status) { + case "running": + return ( + + + + ); + case "completed": + return ; + case "failed": + return ; + case "cancelled": + return ; + default: + return null; + } + }; + + // Get status badge color + const getStatusBadgeClass = () => { + switch (execution.status) { + case "running": + return "bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300"; + case "completed": + return "bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-300"; + case "failed": + return "bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300"; + case "cancelled": + return "bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-300"; + default: + return "bg-neutral-100 text-neutral-700 dark:bg-neutral-900/30 dark:text-neutral-300"; + } + }; + + // Calculate completion percentage + const completionPercent = (() => { + if (execution.progressPercent !== undefined) { + return execution.progressPercent; + } + const completedPhases = execution.phases.filter( + (p) => p.status === "completed" || p.status === "skipped", + ).length; + return execution.phases.length > 0 + ? Math.round((completedPhases / execution.phases.length) * 100) + : 0; + })(); + + const completedCount = execution.phases.filter( + (p) => p.status === "completed" || p.status === "skipped", + ).length; + + return ( +
+ {/* Overall Header */} + +
+ {/* Agent Icon */} + + + + {/* Agent Name */} + + {execution.agentName} + + + {/* Status Icon */} + {getStatusIcon()} + + {/* Status Badge */} + + {t(`app.chat.agent.status.${execution.status}`, { + defaultValue: + execution.status.charAt(0).toUpperCase() + + execution.status.slice(1), + })} + + + {/* Duration */} + {execution.durationMs !== undefined && ( + + {formatDuration(execution.durationMs)} + + )} +
+ + {/* Progress Bar */} + {execution.phases.length > 0 && ( +
+
+ + {/* Shimmer effect for running */} + {isExecuting && ( + + )} + +
+
+ + {completedCount}/{execution.phases.length} phases + + {completionPercent}% +
+
+ )} + + {/* Current Status Message */} + {isExecuting && ( +
+ + + + {execution.progressMessage || + execution.currentNode || + execution.currentPhase || + t("app.chat.agent.initializing", { + defaultValue: "Initializing...", + })} + + + + {/* Iteration indicator */} + {execution.iteration && ( +
+ {t("app.chat.agent.iteration", { + current: execution.iteration.current, + max: execution.iteration.max, + defaultValue: `Iteration ${execution.iteration.current}/${execution.iteration.max}`, + })} +
+ )} +
+ )} +
+ + {/* Phase Timeline - Compact style matching ToolCallCard */} + {execution.phases.length > 0 && ( +
+              {execution.phases.map((phase, index) => {
+                const isActive = phase.status === "running";
+                const isFinalPhase = index === execution.phases.length - 1;
+                // Show content for:
+                // 1. The active phase (currently streaming)
+                // 2. Any non-final phase (its content is always shown)
+                // Final-phase content is shown below the timeline after completion.
+                const shouldShowContent = isActive || !isFinalPhase;
+                return (
+
+ ); + })} +
+ )} + + {/* Subagents */} + {execution.subagents.length > 0 && ( +
+
+ {t("app.chat.agent.subagents", { + defaultValue: "Subagents", + })} +
+ {execution.subagents.map((subagent) => ( +
+ └─ + {subagent.status === "running" ? ( + + ) : subagent.status === "completed" ? ( + + ) : ( + + )} + + {subagent.name} + + {subagent.durationMs !== undefined && ( + + {formatDuration(subagent.durationMs)} + + )} +
+ ))} +
+ )} + + {/* Error Display */} + {execution.error && ( + +
+ +
+
+ {execution.error.type} +
+
+ {execution.error.message} +
+ {execution.error.nodeId && ( +
+ Node: {execution.error.nodeId} +
+ )} +
+
+
+ )} +
+ ); +} diff --git a/web/src/components/layouts/components/AgentNodeItem.tsx b/web/src/components/layouts/components/AgentNodeItem.tsx new file mode 100644 index 00000000..57a23af1 --- /dev/null +++ b/web/src/components/layouts/components/AgentNodeItem.tsx @@ -0,0 +1,182 @@ +import { AnimatePresence, motion } from "framer-motion"; +import { + ChevronDownIcon, + CheckIcon, + ExclamationTriangleIcon, +} from "@heroicons/react/24/solid"; +import { useState, useEffect } from "react"; +import type { ExecutionStatus } from "@/types/agentEvents"; +import Markdown from "@/lib/Markdown"; +import LoadingMessage from "./LoadingMessage"; + +interface AgentNodeItemProps { + nodeName: string; + status: ExecutionStatus; + content?: string; + isActive?: boolean; + className?: string; +} + +/** + * AgentNodeItem displays a single node/phase in a compact tool-call style. + * Designed to match ToolCallCard appearance with single-line header and expandable content. + */ +export default function AgentNodeItem({ + nodeName, + status, + content, + isActive = false, + className = "", +}: AgentNodeItemProps) { + const [isExpanded, setIsExpanded] = useState(isActive); + + // Auto-expand when node becomes active + useEffect(() => { + if (isActive && !isExpanded) { + setIsExpanded(true); + } + }, [isActive, isExpanded]); + + // Check if there's meaningful content to show + const hasContent = content && content.trim().length > 0; + const canExpand = hasContent && status !== "pending"; + + // Get status indicator matching ToolCallCard style + const getStatusIndicator = () => { + switch (status) { + case "pending": + return ( + + ); + case "running": + return ; + case "completed": + return ( + + + + ); + case "failed": + return ( + + + + ); + case "skipped": + return ( + + ); + default: + return null; + } + }; + + // Get right side indicator + const getRightIndicator = () => { + if (status === "pending") { + return ( + + pending + + ); + } + + if (status === "running" && !hasContent) { + return ( + + running... + + ); + } + + if (!hasContent && status !== "running") { + return ( + + no content + + ); + } + + // Has content - show expand toggle + return ( + + + + ); + }; + + return ( + + {/* Header - Single line, matching ToolCallCard */} +
setIsExpanded(!isExpanded) : undefined} + > +
+ {getStatusIndicator()} + + {nodeName} + +
+ + {getRightIndicator()} +
+ + {/* Streaming content - shown inline when active */} + {isActive && hasContent && ( + +
+ +
+
+ )} + + {/* Expandable content - for completed phases */} + + {isExpanded && !isActive && hasContent && ( + +
+
+ +
+
+
+ )} +
+
+ ); +} diff --git a/web/src/components/layouts/components/AgentPhaseCard.tsx b/web/src/components/layouts/components/AgentPhaseCard.tsx new file mode 100644 index 00000000..99d25cf9 --- /dev/null +++ b/web/src/components/layouts/components/AgentPhaseCard.tsx @@ -0,0 +1,339 @@ +import { AnimatePresence, motion } from "framer-motion"; +import { + ChevronDown, + ChevronRight, + CheckCircle2, + XCircle, + Loader2, + Clock, + Pause, +} from "lucide-react"; +import { useState, useEffect } from "react"; +import { useTranslation } from "react-i18next"; +import type { PhaseExecution } from "@/types/agentEvents"; +import Markdown from "@/lib/Markdown"; + +interface AgentPhaseCardProps { + phase: PhaseExecution; + isActive: boolean; + index: number; +} + +/** + * AgentPhaseCard displays a single phase/node in the agent execution timeline. + * Mirrors the ToolCallCard design with consistent colors, spacing, and interactions. + * + * States: pending, running, completed, failed, skipped + */ +export default function AgentPhaseCard({ + phase, + isActive, + index, +}: AgentPhaseCardProps) { + const { t } = useTranslation(); + const [isExpanded, setIsExpanded] = useState(isActive); + + // Auto-expand when phase becomes active + useEffect(() => { + if (isActive && !isExpanded) { + setIsExpanded(true); + } + }, [isActive, isExpanded]); + + // Format duration in human-readable format + const formatDuration = (ms?: number): string => { + if (ms === undefined) return "-"; + if (ms < 1000) return `${ms}ms`; + const seconds = ms / 1000; + if (seconds < 60) return `${seconds.toFixed(1)}s`; + const minutes = Math.floor(seconds / 60); + const remainingSeconds = (seconds % 60).toFixed(0); + return `${minutes}m ${remainingSeconds}s`; + }; + + // Calculate current duration for running phases + const getCurrentDuration = (): number | undefined => { + if (phase.status === "running" && phase.startedAt) { + return Date.now() - phase.startedAt; + } + return phase.durationMs; + }; + + // Get status icon + const getStatusIcon = () => { + switch (phase.status) { + case "pending": + return ; + case "running": + return ( + + + + ); + case "completed": + return ( + + + + ); + case "failed": + return ( + + + + ); + case "skipped": + return ( + + ); + default: + return null; + } + }; + + // Get card styling based on status + const getCardStyle = () => { + const baseStyle = "rounded-lg border transition-all duration-300"; + + switch (phase.status) { + case "pending": + return `${baseStyle} border-neutral-200 bg-neutral-50/50 dark:border-neutral-700 dark:bg-neutral-800/30`; + case "running": + return `${baseStyle} border-blue-400 bg-gradient-to-br from-blue-50/80 via-indigo-50/60 to-cyan-50/80 dark:border-blue-500/50 dark:from-blue-950/40 dark:via-indigo-950/30 dark:to-cyan-950/40 shadow-sm`; + case "completed": + return `${baseStyle} border-green-300 bg-green-50/50 dark:border-green-700/50 dark:bg-green-900/20`; + case "failed": + return `${baseStyle} border-red-400 bg-red-50/50 dark:border-red-700/50 dark:bg-red-900/20`; + case "skipped": + return `${baseStyle} border-yellow-300 bg-yellow-50/50 dark:border-yellow-700/50 dark:bg-yellow-900/20`; + default: + return `${baseStyle} border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-800`; + } + }; + + // Get text color based on status + const getTextColor = () => { + switch (phase.status) { + case "pending": + return "text-neutral-600 dark:text-neutral-400"; + case "running": + return "text-blue-900 dark:text-blue-100"; + case "completed": + return "text-green-900 
dark:text-green-100"; + case "failed": + return "text-red-900 dark:text-red-100"; + case "skipped": + return "text-yellow-900 dark:text-yellow-100"; + default: + return "text-neutral-700 dark:text-neutral-300"; + } + }; + + // Get status badge + const getStatusBadge = () => { + const badgeClass = "rounded-full px-2 py-0.5 text-xs font-medium"; + + switch (phase.status) { + case "pending": + return ( + + {t("app.chat.agent.phase.pending", { defaultValue: "Pending" })} + + ); + case "running": + return ( + + {t("app.chat.agent.phase.running", { defaultValue: "Running" })} + + ); + case "completed": + return ( + + {t("app.chat.agent.phase.completed", { defaultValue: "Completed" })} + + ); + case "failed": + return ( + + {t("app.chat.agent.phase.failed", { defaultValue: "Failed" })} + + ); + case "skipped": + return ( + + {t("app.chat.agent.phase.skipped", { defaultValue: "Skipped" })} + + ); + default: + return null; + } + }; + + const hasContent = + phase.streamedContent || phase.outputSummary || phase.nodes.length > 0; + const canExpand = hasContent && phase.status !== "pending"; + + return ( + + {/* Phase Card */} +
+ {/* Shimmer effect for running phases */} + {phase.status === "running" && ( + + )} + + {/* Header - Always visible */} + + + {/* Expandable Content */} + + {isExpanded && hasContent && ( + +
+ {/* Streaming Content */} + {phase.streamedContent && ( + + + + )} + + {/* Output Summary */} + {phase.outputSummary && !phase.streamedContent && ( +

+ {phase.outputSummary} +

+ )} + + {/* Nested Nodes (if any) */} + {phase.nodes.length > 0 && ( +
+
+ {t("app.chat.agent.nodes", { defaultValue: "Nodes" })} +
+ {phase.nodes.map((node) => ( +
+ {node.status === "running" ? ( + + ) : node.status === "completed" ? ( + + ) : ( + + )} + + {node.name} + + {node.durationMs && ( + + {formatDuration(node.durationMs)} + + )} +
+ ))} +
+ )} +
+
+ )} +
+
+
+ ); +} diff --git a/web/src/components/layouts/components/AgentPhaseItem.tsx b/web/src/components/layouts/components/AgentPhaseItem.tsx new file mode 100644 index 00000000..be64a426 --- /dev/null +++ b/web/src/components/layouts/components/AgentPhaseItem.tsx @@ -0,0 +1,172 @@ +import { motion, AnimatePresence } from "framer-motion"; +import { + CheckCircle2, + XCircle, + Loader2, + Circle, + ChevronRight, + SkipForward, +} from "lucide-react"; +import { useState } from "react"; +import type { PhaseExecution } from "@/types/agentEvents"; +import Markdown from "@/lib/Markdown"; + +interface AgentPhaseItemProps { + phase: PhaseExecution; + formatDuration: (ms?: number) => string; +} + +/** + * AgentPhaseItem displays a single phase in the agent execution timeline. + * Can be expanded to show node-level details or streamed content. + */ +export default function AgentPhaseItem({ + phase, + formatDuration, +}: AgentPhaseItemProps) { + const [isExpanded, setIsExpanded] = useState(false); + const hasNodes = phase.nodes && phase.nodes.length > 0; + const hasExpandableContent = hasNodes || !!phase.streamedContent; + + // Get status icon for phase + const getPhaseStatusIcon = () => { + switch (phase.status) { + case "running": + return ; + case "completed": + return ; + case "failed": + return ; + case "skipped": + return ; + case "pending": + default: + return ( + + ); + } + }; + + // Get node status icon + const getNodeStatusIcon = (status: string) => { + switch (status) { + case "running": + return ; + case "completed": + return ; + case "failed": + return ; + case "skipped": + return ; + default: + return ; + } + }; + + return ( +
+    {/* Phase row */}
+
+
+    {/* Expanded node details */}
+    <AnimatePresence>
+      {isExpanded && hasNodes && (
+
+          <div>
+            {phase.nodes.map((node) => (
+              <div>
+                {/* Node status icon */}
+                {getNodeStatusIcon(node.status)}
+
+                {/* Node name */}
+                <span>
+                  {node.name}
+                  {node.type && (
+                    <span>
+                      ({node.type})
+                    </span>
+                  )}
+                </span>
+
+                {/* Node duration */}
+                {node.durationMs !== undefined && (
+                  <span>
+                    {formatDuration(node.durationMs)}
+                  </span>
+                )}
+              </div>
+            ))}
+          </div>
+
+      )}
+    </AnimatePresence>
+
+    {/* Expandable streamed content */}
+    <AnimatePresence>
+      {isExpanded && phase.streamedContent && (
+
+          <div>
+
+          </div>
+
+      )}
+    </AnimatePresence>
+
+    {/* Output summary (if available and phase is completed) */}
+    {phase.outputSummary && phase.status === "completed" && (
+      <div>
+        {phase.outputSummary}
+      </div>
+    )}
+
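As a usage sketch, a parent timeline could render a list of phases like this (PhaseList is an assumed name, not part of this patch; formatDuration is copied verbatim from the AgentNodeCard helper later in this patch, so formatDuration(850) === "850ms", formatDuration(12340) === "12.3s", and formatDuration(83000) === "1m 23s"):

// Sketch only; assumes AgentPhaseItem lives alongside this file.
import type { PhaseExecution } from "@/types/agentEvents";
import AgentPhaseItem from "./AgentPhaseItem";

const formatDuration = (ms?: number): string => {
  if (ms === undefined) return "-";
  if (ms < 1000) return `${ms}ms`;
  const seconds = ms / 1000;
  if (seconds < 60) return `${seconds.toFixed(1)}s`;
  const minutes = Math.floor(seconds / 60);
  const remainingSeconds = (seconds % 60).toFixed(0);
  return `${minutes}m ${remainingSeconds}s`;
};

export function PhaseList({ phases }: { phases: PhaseExecution[] }) {
  return (
    <>
      {phases.map((phase, index) => (
        <AgentPhaseItem
          key={index}
          phase={phase}
          formatDuration={formatDuration}
        />
      ))}
    </>
  );
}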
+  );
+}
diff --git a/web/src/components/layouts/components/AgentProgressBar.tsx b/web/src/components/layouts/components/AgentProgressBar.tsx
new file mode 100644
index 00000000..3909c402
--- /dev/null
+++ b/web/src/components/layouts/components/AgentProgressBar.tsx
@@ -0,0 +1,105 @@
+import { motion } from "framer-motion";
+import type { PhaseExecution } from "@/types/agentEvents";
+
+interface AgentProgressBarProps {
+  percent: number;
+  phases?: PhaseExecution[];
+}
+
+/**
+ * AgentProgressBar displays an animated progress bar with optional phase
+ * indicators.
+ */
+export default function AgentProgressBar({
+  percent,
+  phases = [],
+}: AgentProgressBarProps) {
+  // Derive phase-based progress counts when phases are provided
+  const phaseCount = phases.length;
+  const completedPhases = phases.filter(
+    (p) => p.status === "completed" || p.status === "skipped",
+  ).length;
+  const runningPhase = phases.findIndex((p) => p.status === "running");
+
+  // Use phase-based progress if phases exist, otherwise fall back to percent
+  const displayPercent =
+    phaseCount > 0
+      ? Math.min(
+          100,
+          ((completedPhases + (runningPhase >= 0 ? 0.5 : 0)) / phaseCount) *
+            100,
+        )
+      : percent;
+
+  return (
+
+      {/* Background track */}
+
+        {/* Animated progress fill */}
+
+
+        {/* Shimmer effect on the progress bar */}
+        <motion.div
+          animate={{
+            opacity:
+              displayPercent > 0 && displayPercent < 100 ? [0.3, 0.7, 0.3] : 0,
+          }}
+          transition={{
+            duration: 1.5,
+            repeat: Infinity,
+            ease: "easeInOut",
+          }}
+        />
+
+
+      {/* Phase dots (if phases exist) */}
+      {phaseCount > 1 && (
+
+          {phases.map((phase, index) => {
+            const position = ((index + 1) / phaseCount) * 100;
+            return (
+
+            );
+          })}
+
+      )}
+
+      {/* Percentage text */}
+      {displayPercent > 0 && (
+        <span>
+          {Math.round(displayPercent)}%
+        </span>
+      )}
+
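To make the phase-based fallback concrete: with five phases, two of them completed or skipped and one running, the bar shows ((2 + 0.5) / 5) * 100 = 50, and the percent prop is ignored. The same computation in standalone form (a sketch, not part of the patch):

// Pure-function mirror of the displayPercent logic above.
function phaseProgress(statuses: string[], fallbackPercent: number): number {
  if (statuses.length === 0) return fallbackPercent;
  const done = statuses.filter(
    (s) => s === "completed" || s === "skipped",
  ).length;
  const runningBonus = statuses.includes("running") ? 0.5 : 0;
  return Math.min(100, ((done + runningBonus) / statuses.length) * 100);
}

// phaseProgress(["completed", "completed", "running", "pending", "pending"], 0)
// returns 50.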
+  );
+}
diff --git a/web/src/components/layouts/components/ChatBubble.tsx b/web/src/components/layouts/components/ChatBubble.tsx
index 3d7e5f20..77b345b0 100644
--- a/web/src/components/layouts/components/ChatBubble.tsx
+++ b/web/src/components/layouts/components/ChatBubble.tsx
@@ -7,6 +7,7 @@ import type { Message } from "@/store/types";
 import { CheckIcon, ClipboardDocumentIcon } from "@heroicons/react/24/outline";
 import { motion } from "framer-motion";
 import { useDeferredValue, useMemo, useState } from "react";
+import AgentExecutionTimeline from "./AgentExecutionTimeline";
 import LoadingMessage from "./LoadingMessage";
 import MessageAttachments from "./MessageAttachments";
 import { SearchCitations } from "./SearchCitations";
@@ -35,6 +36,7 @@ function ChatBubble({ message }: ChatBubbleProps) {
     citations,
     isThinking,
     thinkingContent,
+    agentExecution,
   } = message;
 
   // Typewriter effect for streaming messages
@@ -205,6 +207,17 @@
           : "text-sm text-neutral-700 dark:text-neutral-300"
       }`}
     >
+      {/* Agent execution timeline - only shown for multi-phase agent types */}
+      {/* Skip the timeline for react/simple agents - they show content directly */}
+      {!isUserMessage &&
+        agentExecution &&
+        agentExecution.agentType !== "react" && (
+
+        )}
+
       {/* Thinking content - shown before main response for assistant messages */}
       {!isUserMessage && thinkingContent && (
 
       ) : (
+        // Show markdownContent when:
+        // 1. there is no agentExecution (regular chat), or
+        // 2. agentExecution is the react type (a simple agent without a timeline).
+        // For multi-phase agents, content is shown inside the timeline phases.
+        (!agentExecution || agentExecution.agentType === "react") &&
         markdownContent
       )}
+
+      {/* For multi-phase agents, show the final report below the timeline when completed */}
+      {!isUserMessage &&
+        !isStreaming &&
+        agentExecution &&
+        agentExecution.agentType !== "react" &&
+        agentExecution.status !== "running" &&
+        agentExecution.phases.length > 0 &&
+        (() => {
+          const finalPhase =
+            agentExecution.phases[agentExecution.phases.length - 1];
+          if (finalPhase?.streamedContent) {
+            return (
+
+
+            );
+          }
+          return null;
+        })()}
+
+      {isStreaming && !isLoading && (
+
+  useEffect(() => {
+    if (isActive && !isExpanded) {
+      setIsExpanded(true);
+    }
+  }, [isActive, isExpanded]);
+
+  // Early return after all hooks
+  if (!node) return null;
+
+  // Format duration
+  const formatDuration = (ms?: number): string => {
+    if (ms === undefined) return "-";
+    if (ms < 1000) return `${ms}ms`;
+    const seconds = ms / 1000;
+    if (seconds < 60) return `${seconds.toFixed(1)}s`;
+    const minutes = Math.floor(seconds / 60);
+    const remainingSeconds = (seconds % 60).toFixed(0);
+    return `${minutes}m ${remainingSeconds}s`;
+  };
+
+  // Calculate the current duration for running nodes
+  const getCurrentDuration = (): number | undefined => {
+    if (node.status === "running" && node.started_at) {
+      return Date.now() - node.started_at;
+    }
+    return node.duration_ms;
+  };
+
+  // Get status icon
+  const getStatusIcon = () => {
+    switch (node.status) {
+      case "pending":
+        return (
+
+        );
+      case "running":
+        return (
+
+
+
+        );
+      case "completed":
+        return (
+
+
+
+        );
+      case "failed":
+        return (
+
+
+
+        );
+      case "skipped":
+        return (
+
+        );
+      default:
+        return null;
+    }
+  };
+
+  // Get card styling based on status
+  const getCardStyle = () => {
+    const baseStyle =
+      "transition-all duration-200 hover:bg-black/5 dark:hover:bg-white/5";
+
+    switch (node.status) {
+      case "pending":
+        return `${baseStyle} border-neutral-200 bg-neutral-50/50 dark:border-neutral-700 dark:bg-neutral-800/30`;
+      case "running":
+        return `${baseStyle} border-blue-400 bg-gradient-to-br from-blue-50/80 via-indigo-50/60 to-cyan-50/80 dark:border-blue-500/50 dark:from-blue-950/40 dark:via-indigo-950/30 dark:to-cyan-950/40`;
+      case "completed":
+        return `${baseStyle} border-green-300 bg-green-50/50 dark:border-green-700/50 dark:bg-green-900/20`;
+      case "failed":
+        return `${baseStyle} border-red-400 bg-red-50/50 dark:border-red-700/50 dark:bg-red-900/20`;
+      case "skipped":
+        return `${baseStyle} border-yellow-300 bg-yellow-50/50 dark:border-yellow-700/50 dark:bg-yellow-900/20`;
+      default:
+        return `${baseStyle} border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-800`;
+    }
+  };
+
+  // Get text color based on status
+  const getTextColor = () => {
+    switch (node.status) {
+      case "pending":
+        return "text-neutral-600 dark:text-neutral-400";
+      case "running":
+        return "text-blue-900 dark:text-blue-100";
+      case "completed":
+        return "text-green-900 dark:text-green-100";
+      case "failed":
+        return "text-red-900 dark:text-red-100";
+      case "skipped":
+        return "text-yellow-900 dark:text-yellow-100";
+      default:
+        return "text-neutral-700 dark:text-neutral-300";
+    }
+  };
+
+  const hasContent = content || node.output_summary;
+  const canExpand = hasContent && node.status !== "pending";
+
+  return (
+
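The gating rules that the ChatBubble hunk above spreads across three JSX conditions reduce to two predicates; the sketch below uses assumed type and helper names, not identifiers from this patch:

// Sketch; AgentExecutionLike and both helpers are assumed names.
interface AgentExecutionLike {
  agentType: string;
  status: string;
  phases: { streamedContent?: string }[];
}

// The timeline renders only on assistant messages for multi-phase agents.
const showTimeline = (isUserMessage: boolean, exec?: AgentExecutionLike) =>
  !isUserMessage && !!exec && exec.agentType !== "react";

// Plain markdown renders when there is no agent execution at all, or when
// the agent is the simple "react" type; multi-phase output lives in the
// timeline phases instead.
const showMarkdown = (exec?: AgentExecutionLike) =>
  !exec || exec.agentType === "react";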
+      {/* Shimmer effect for running nodes */}
+      {node.status === "running" && (
+
+      )}
+
+      {/* Header - Always visible */}
+
+
+      {/* Expandable Content */}
+
+      {isExpanded && hasContent && (
+
+          {/* Streaming Content */}
+          {content && (
+
+
+
+          )}
+
+          {/* Output Summary */}
+          {node.output_summary && !content && (
+            <div>
+              {node.output_summary}
+            </div>
+          )}
+
+          {/* Error Display */}
+          {metadata.error && node.status === "failed" && (
+            <div>
+              <div>
+                {metadata.error.type}
+              </div>
+              <div>
+                {metadata.error.message}
+              </div>
+            </div>
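The error block consumes a metadata prop whose declaration is not visible in this hunk; from the usage above it is at least the following (a sketch, with field names exactly as used here):

// Sketch inferred from metadata.error.type / metadata.error.message above.
interface NodeCardMetadata {
  error?: {
    type: string; // error class name
    message: string;
  };
}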
+
+ )} +
+
+
+  );
+}
diff --git a/web/src/components/modals/AddAgentModal.tsx b/web/src/components/modals/AddAgentModal.tsx
index 4bc1da14..b84608a1 100644
--- a/web/src/components/modals/AddAgentModal.tsx
+++ b/web/src/components/modals/AddAgentModal.tsx
@@ -1,9 +1,22 @@
 import { Modal } from "@/components/animate-ui/primitives/headless/modal";
 import { Input } from "@/components/base/Input";
 import { useXyzen } from "@/store";
-import type { Agent } from "@/types/agents";
-import { Button, Field, Label } from "@headlessui/react";
-import { PlusIcon } from "@heroicons/react/24/outline";
+import type { Agent, SystemAgentTemplate } from "@/types/agents";
+import {
+  Button,
+  Field,
+  Label,
+  Tab,
+  TabGroup,
+  TabList,
+  TabPanel,
+  TabPanels,
+} from "@headlessui/react";
+import {
+  PlusIcon,
+  BeakerIcon,
+  SparklesIcon,
+} from "@heroicons/react/24/outline";
 import React, { useEffect, useState } from "react";
 import { McpServerItem } from "./McpServerItem";
 
@@ -12,15 +25,27 @@ interface AddAgentModalProps {
   onClose: () => void;
 }
 
+type TabMode = "custom" | "system";
+
 function AddAgentModal({ isOpen, onClose }: AddAgentModalProps) {
   const {
     createAgent,
+    createAgentFromTemplate,
     isCreatingAgent,
     mcpServers,
     fetchMcpServers,
     openAddMcpServerModal,
+    systemAgentTemplates,
+    templatesLoading,
+    fetchSystemAgentTemplates,
   } = useXyzen();
 
+  const [tabMode, setTabMode] = useState<TabMode>("custom");
+  const [selectedTemplateKey, setSelectedTemplateKey] = useState<string | null>(
+    null,
+  );
+  const [customName, setCustomName] = useState("");
+
   const [agent, setAgent] = useState<
     Omit<
       Agent,
@@ -39,12 +64,13 @@
   const [mcpServerIds, setMcpServerIds] = useState([]);
   const [isSubmitting, setIsSubmitting] = useState(false);
 
-  // Fetch MCP servers when modal opens
+  // Fetch MCP servers and system agent templates when the modal opens
   useEffect(() => {
     if (isOpen) {
       fetchMcpServers();
+      fetchSystemAgentTemplates();
     }
-  }, [isOpen, fetchMcpServers]);
+  }, [isOpen, fetchMcpServers, fetchSystemAgentTemplates]);
 
   const handleChange = (
     e: React.ChangeEvent,
   ) => {
@@ -61,7 +87,15 @@
     );
   };
 
-  const buildAgentPayload = () => ({
+  const handleTemplateSelect = (template: SystemAgentTemplate) => {
+    setSelectedTemplateKey(template.key);
+    // Pre-fill the name with the template name if no custom name is set
+    if (!customName) {
+      setCustomName(template.metadata.name);
+    }
+  };
+
+  const buildCustomAgentPayload = () => ({
     ...agent,
     mcp_server_ids: mcpServerIds,
     user_id: "temp", // Backend will get this from auth token
@@ -77,11 +111,23 @@
     setIsSubmitting(true);
     try {
-      if (!agent.name) {
-        alert("助手名称不能为空");
-        return;
+      if (tabMode === "custom") {
+        if (!agent.name) {
+          alert("助手名称不能为空");
+          return;
+        }
+        await createAgent(buildCustomAgentPayload());
+      } else {
+        if (!selectedTemplateKey) {
+          alert("请选择一个系统助手");
+          return;
+        }
+        // Use the new from-template endpoint
+        await createAgentFromTemplate(
+          selectedTemplateKey,
+          customName || undefined,
+        );
       }
-      await createAgent(buildAgentPayload());
       handleClose();
     } catch (error) {
       console.error("Failed to create agent:", error);
@@ -91,7 +137,11 @@
     }
   };
 
-  const submitDisabled = isSubmitting || isCreatingAgent || !agent.name;
+  const submitDisabled =
+    tabMode === "custom"
+      ? isSubmitting || isCreatingAgent || !agent.name
+      : isSubmitting || isCreatingAgent || !selectedTemplateKey;
+
   const submitLabel =
     isSubmitting || isCreatingAgent ? "创建中..." : "创建助手";
 
@@ -102,127 +152,332 @@
       prompt: "",
     });
     setMcpServerIds([]);
+    setSelectedTemplateKey(null);
+    setCustomName("");
+    setTabMode("custom");
     onClose();
   };
 
+  const handleTabChange = (index: number) => {
+    setTabMode(index === 0 ? "custom" : "system");
+  };
+
   return (
-
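Both branches of the submitDisabled ternary share the in-flight checks and differ only in the required input, so it is equivalent to this sketch:

// Equivalent form of submitDisabled; only names already in scope above.
const busy = isSubmitting || isCreatingAgent;
const hasRequiredInput =
  tabMode === "custom" ? Boolean(agent.name) : Boolean(selectedTemplateKey);
const submitDisabledEquivalent = busy || !hasRequiredInput;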

-        创建一个新的 AI 助手,可以配置专属提示词和工具。
-
-
-
-
-
-
-
-
-
-
-
-
-
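For orientation, handleTemplateSelect above only touches template.key and template.metadata.name, so the template type this modal consumes is at least the following (a sketch; the real SystemAgentTemplate is exported from @/types/agents and may carry more fields):

// Minimal shape assumed by the modal code above.
interface SystemAgentTemplateSketch {
  key: string; // stable identifier passed to createAgentFromTemplate
  metadata: {
    name: string; // display name used to pre-fill the agent name
  };
}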