diff --git a/python/packages/kagent-core/pyproject.toml b/python/packages/kagent-core/pyproject.toml index 8efb26a04..1aa25b21d 100644 --- a/python/packages/kagent-core/pyproject.toml +++ b/python/packages/kagent-core/pyproject.toml @@ -9,7 +9,7 @@ description = "kagent common library for kagent python packages" readme = "README.md" requires-python = ">=3.11.0" dependencies = [ - "a2a-sdk>=0.2.16", + "a2a-sdk[http-server]>=0.3.9", "opentelemetry-api>=1.36.0", "opentelemetry-sdk>=1.36.0", "opentelemetry-exporter-otlp-proto-grpc>=1.36.0", diff --git a/python/packages/kagent-core/src/kagent/core/a2a/__init__.py b/python/packages/kagent-core/src/kagent/core/a2a/__init__.py index 5c4155315..0665d4123 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/__init__.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/__init__.py @@ -5,8 +5,24 @@ A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, A2A_DATA_PART_METADATA_TYPE_KEY, + KAGENT_HITL_DECISION_TYPE_APPROVE, + KAGENT_HITL_DECISION_TYPE_DENY, + KAGENT_HITL_DECISION_TYPE_KEY, + KAGENT_HITL_DECISION_TYPE_REJECT, + KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL, + KAGENT_HITL_RESUME_KEYWORDS_APPROVE, + KAGENT_HITL_RESUME_KEYWORDS_DENY, get_kagent_metadata_key, ) +from ._hitl import ( + DecisionType, + ToolApprovalRequest, + escape_markdown_backticks, + extract_decision_from_message, + format_tool_approval_text_parts, + handle_tool_approval_interrupt, + is_input_required_task, +) from ._requests import KAgentRequestContextBuilder from ._task_result_aggregator import TaskResultAggregator from ._task_store import KAgentTaskStore @@ -22,4 +38,22 @@ "A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT", "A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE", "TaskResultAggregator", + # HITL constants + "KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL", + "KAGENT_HITL_DECISION_TYPE_KEY", + "KAGENT_HITL_DECISION_TYPE_APPROVE", + "KAGENT_HITL_DECISION_TYPE_DENY", + "KAGENT_HITL_DECISION_TYPE_REJECT", + "KAGENT_HITL_RESUME_KEYWORDS_APPROVE", + "KAGENT_HITL_RESUME_KEYWORDS_DENY", + # HITL types + "DecisionType", + "ToolApprovalRequest", + # HITL utilities + "escape_markdown_backticks", + "extract_decision_from_message", + "format_tool_approval_text_parts", + "is_input_required_task", + # HITL handlers + "handle_tool_approval_interrupt", ] diff --git a/python/packages/kagent-core/src/kagent/core/a2a/_consts.py b/python/packages/kagent-core/src/kagent/core/a2a/_consts.py index cfcb33745..ac3eecb7f 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/_consts.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/_consts.py @@ -23,3 +23,13 @@ def get_kagent_metadata_key(key: str) -> str: if not key: raise ValueError("Metadata key cannot be empty or None") return f"{KAGENT_METADATA_KEY_PREFIX}{key}" + + +# Human-in-the-Loop (HITL) Constants +KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL = "tool_approval" +KAGENT_HITL_DECISION_TYPE_KEY = "decision_type" +KAGENT_HITL_DECISION_TYPE_APPROVE = "approve" +KAGENT_HITL_DECISION_TYPE_DENY = "deny" +KAGENT_HITL_DECISION_TYPE_REJECT = "reject" +KAGENT_HITL_RESUME_KEYWORDS_APPROVE = ["approved", "approve", "proceed", "yes", "continue"] +KAGENT_HITL_RESUME_KEYWORDS_DENY = ["denied", "deny", "reject", "no", "cancel", "stop"] diff --git a/python/packages/kagent-core/src/kagent/core/a2a/_hitl.py b/python/packages/kagent-core/src/kagent/core/a2a/_hitl.py new file mode 100644 index 000000000..9d55496b8 --- /dev/null +++ b/python/packages/kagent-core/src/kagent/core/a2a/_hitl.py @@ -0,0 +1,315 @@ +"""Human-in-the-Loop (HITL) support for kagent executors. + +This module provides types, utilities, and handlers for implementing +human-in-the-loop workflows in kagent agent executors using A2A protocol primitives. +""" + +import logging +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, Literal + +from a2a.server.events.event_queue import EventQueue +from a2a.server.tasks import TaskStore +from a2a.types import ( + DataPart, + Message, + Part, + Role, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) + +from ._consts import ( + KAGENT_HITL_DECISION_TYPE_APPROVE, + KAGENT_HITL_DECISION_TYPE_DENY, + KAGENT_HITL_DECISION_TYPE_KEY, + KAGENT_HITL_DECISION_TYPE_REJECT, + KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL, + KAGENT_HITL_RESUME_KEYWORDS_APPROVE, + KAGENT_HITL_RESUME_KEYWORDS_DENY, + get_kagent_metadata_key, +) + +logger = logging.getLogger(__name__) + +# Type definitions + +DecisionType = Literal["approve", "deny", "reject"] +"""Type for user decisions in HITL workflows.""" + + +@dataclass +class ToolApprovalRequest: + """Generic structure for a tool call requiring approval. + + Any agent framework can map their tool calls to this structure. + + Attributes: + name: The name of the tool/function being called + args: Dictionary of arguments to pass to the tool + id: Optional unique identifier for this specific tool call + """ + + name: str + args: dict[str, Any] + id: str | None = None + + +# Utility functions + + +def escape_markdown_backticks(text: str) -> str: + """Escape backticks in text to prevent markdown formatting issues. + + Used when displaying code, tool names, or arguments in markdown-formatted + approval messages. + + Args: + text: Text that may contain backticks + + Returns: + Text with all backticks escaped with backslash + + Examples: + >>> escape_markdown_backticks("function `foo`") + 'function \\`foo\\`' + """ + return str(text).replace("`", "\\`") + + +def is_input_required_task(task_state: TaskState | None) -> bool: + """Check if task state indicates waiting for user input. + + Args: + task_state: Current task state, or None if no task + + Returns: + True if task is in input_required state + """ + return task_state == TaskState.input_required + + +def extract_decision_from_data_part(data: dict) -> DecisionType | None: + """Extract decision type from structured DataPart. + + Looks for the decision_type key in the data dictionary and validates + it's a known decision value. + + Args: + data: DataPart.data dictionary + + Returns: + Decision type if found and valid, None otherwise + """ + decision = data.get(KAGENT_HITL_DECISION_TYPE_KEY) + if decision in ( + KAGENT_HITL_DECISION_TYPE_APPROVE, + KAGENT_HITL_DECISION_TYPE_DENY, + KAGENT_HITL_DECISION_TYPE_REJECT, + ): + return decision + return None + + +def extract_decision_from_text(text: str) -> DecisionType | None: + """Extract decision from text using keyword matching. + + Searches for approval or denial keywords in the text (case-insensitive). + Denial keywords take priority if both are present (to avoid accidental approval). + + Args: + text: User input text + + Returns: + "deny" if denial keywords found, "approve" if approval keywords found, + None if no keywords found + """ + text_lower = text.lower() + + # Check deny keywords first (safer - prevents accidental approval) + if any(keyword in text_lower for keyword in KAGENT_HITL_RESUME_KEYWORDS_DENY): + return KAGENT_HITL_DECISION_TYPE_DENY + + # Check approve keywords + if any(keyword in text_lower for keyword in KAGENT_HITL_RESUME_KEYWORDS_APPROVE): + return KAGENT_HITL_DECISION_TYPE_APPROVE + + return None + + +def extract_decision_from_message(message: Message | None) -> DecisionType | None: + """Extract decision from A2A message using two-tier detection. + + Priority: + 1. Structured DataPart with decision_type field (most reliable) + 2. Keyword matching in TextPart (fallback for human input) + + DataPart is checked across all parts first before falling back to TextPart, + ensuring structured decisions always take precedence. + + Args: + message: A2A message from user + + Returns: + Decision type if found, None otherwise + """ + if not message or not message.parts: + return None + + # Priority 1: Scan all parts for DataPart with decision (most reliable) + for part in message.parts: + # Access .root for RootModel union types + if not hasattr(part, "root"): + continue + + inner = part.root + + if isinstance(inner, DataPart): + decision = extract_decision_from_data_part(inner.data) + if decision: + return decision + + # Priority 2: Fallback to TextPart keyword matching + for part in message.parts: + if not hasattr(part, "root"): + continue + + inner = part.root + + if isinstance(inner, TextPart): + if inner.text and isinstance(inner.text, str): + decision = extract_decision_from_text(inner.text) + if decision: + return decision + + return None + + +def format_tool_approval_text_parts( + action_requests: list[ToolApprovalRequest], +) -> list[Part]: + """Format tool approval requests as human-readable TextParts. + + Creates a formatted approval message listing all tools and their arguments + with proper markdown escaping to prevent rendering issues. + + Args: + action_requests: List of tool approval request objects + + Returns: + List of Part objects containing formatted approval message + """ + parts = [] + + # Add header + parts.append(Part(TextPart(text="**Approval Required**\n\n"))) + parts.append(Part(TextPart(text="The following actions require your approval:\n\n"))) + + # List each action + for action in action_requests: + tool_name = action.name + tool_args = action.args + + # Escape backticks to prevent markdown breaking + escaped_tool_name = escape_markdown_backticks(tool_name) + parts.append(Part(TextPart(text=f"**Tool**: `{escaped_tool_name}`\n"))) + parts.append(Part(TextPart(text="**Arguments**:\n"))) + + for key, value in tool_args.items(): + escaped_key = escape_markdown_backticks(key) + escaped_value = escape_markdown_backticks(value) + parts.append(Part(TextPart(text=f" • {escaped_key}: `{escaped_value}`\n"))) + + parts.append(Part(TextPart(text="\n"))) + + return parts + + +# High-level handlers + + +async def handle_tool_approval_interrupt( + action_requests: list[ToolApprovalRequest], + task_id: str, + context_id: str, + event_queue: EventQueue, + task_store: TaskStore, + app_name: str | None = None, + review_configs: list[dict[str, Any]] | None = None, +) -> None: + """Send input_required event for tool approval. + + This is a framework-agnostic handler that any executor can call when + it needs user approval for tool calls. It formats an approval message, + sends an input_required event, and waits for the task to be saved. + + Args: + action_requests: List of tool calls requiring approval + task_id: A2A task ID + context_id: A2A context ID + event_queue: Event queue for publishing events + task_store: Task store for synchronization + app_name: Optional application name for metadata + review_configs: Optional framework-specific review configurations + + Raises: + TimeoutError: If task save doesn't complete within 5 seconds (logged as warning) + """ + # Build human-readable message + text_parts = format_tool_approval_text_parts(action_requests) + + # Build structured DataPart for machine processing (client can parse this) + interrupt_data = { + "interrupt_type": KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL, + "action_requests": [{"name": req.name, "args": req.args, "id": req.id} for req in action_requests], + } + + if review_configs: + interrupt_data["review_configs"] = review_configs + + data_part = Part( + DataPart( + data=interrupt_data, + metadata={get_kagent_metadata_key("type"): "interrupt_data"}, + ) + ) + + # Combine message parts + message_parts = text_parts + [data_part] + + # Build event metadata + event_metadata = {"interrupt_type": KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL} + if app_name: + event_metadata["app_name"] = app_name + + # Send input_required event + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=task_id, + status=TaskStatus( + state=TaskState.input_required, + timestamp=datetime.now(UTC).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=message_parts, + ), + ), + context_id=context_id, + final=False, # Not final - waiting for user input + metadata=event_metadata, + ) + ) + + logger.info(f"Interrupt detected, sent input_required event for task {task_id} with {len(action_requests)} actions") + + # Wait for the event consumer to persist the task (event-based sync) + # This prevents race condition where approval arrives before task is saved + try: + await task_store.wait_for_save(task_id, timeout=5.0) + except TimeoutError: + logger.warning("Task save event timeout, proceeding anyway") diff --git a/python/packages/kagent-core/src/kagent/core/a2a/_task_store.py b/python/packages/kagent-core/src/kagent/core/a2a/_task_store.py index bee314f3f..97d1c2161 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/_task_store.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/_task_store.py @@ -1,9 +1,25 @@ +import asyncio + import httpx from a2a.server.tasks import TaskStore from a2a.types import Task +from pydantic import BaseModel from typing_extensions import override +class KAgentTaskResponse(BaseModel): + """Wrapper for KAgent controller API responses. + + The KAgent Go controller wraps all task responses in a StandardResponse envelope + with the format: {"error": bool, "data": T, "message": str}. + This model unwraps that envelope to extract the actual Task object. + """ + + error: bool + data: Task | None = None + message: str | None = None + + class KAgentTaskStore(TaskStore): """ A task store that persists A2A tasks to KAgent via REST API. @@ -16,13 +32,16 @@ def __init__(self, client: httpx.AsyncClient): client: HTTP client configured with KAgent base URL """ self.client = client + # Event-based sync: track pending save operations + self._save_events: dict[str, asyncio.Event] = {} @override - async def save(self, task: Task) -> None: + async def save(self, task: Task, context=None) -> None: """Save a task to KAgent. Args: task: The task to save + context: Server call context (unused, for a2a-sdk 0.3+ compatibility) Raises: httpx.HTTPStatusError: If the API request fails @@ -30,12 +49,17 @@ async def save(self, task: Task) -> None: response = await self.client.post("/api/tasks", json=task.model_dump(mode="json")) response.raise_for_status() + # Signal that save completed (event-based sync) + if task.id in self._save_events: + self._save_events[task.id].set() + @override - async def get(self, task_id: str) -> Task | None: + async def get(self, task_id: str, context=None) -> Task | None: """Retrieve a task from KAgent. Args: task_id: The ID of the task to retrieve + context: Server call context (unused, for a2a-sdk 0.3+ compatibility) Returns: The task if found, None otherwise @@ -47,17 +71,43 @@ async def get(self, task_id: str) -> Task | None: if response.status_code == 404: return None response.raise_for_status() - return Task.model_validate(response.json()) + + # Unwrap the StandardResponse envelope from the Go controller + wrapped = KAgentTaskResponse.model_validate(response.json()) + return wrapped.data @override - async def delete(self, task_id: str) -> None: + async def delete(self, task_id: str, context=None) -> None: """Delete a task from KAgent. Args: task_id: The ID of the task to delete + context: Server call context (unused, for a2a-sdk 0.3+ compatibility) Raises: httpx.HTTPStatusError: If the API request fails """ response = await self.client.delete(f"/api/tasks/{task_id}") response.raise_for_status() + + async def wait_for_save(self, task_id: str, timeout: float = 5.0) -> None: + """Wait for a task to be saved (event-based sync). + + This method is used to synchronize with the save operation instead of + using arbitrary sleep delays. It's particularly useful after interrupts + to ensure the task state is persisted before resuming. + + Args: + task_id: The ID of the task to wait for + timeout: Maximum time to wait in seconds (default: 5.0) + + Raises: + asyncio.TimeoutError: If the save doesn't complete within timeout + """ + event = asyncio.Event() + self._save_events[task_id] = event + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + finally: + # Clean up the event + self._save_events.pop(task_id, None) diff --git a/python/packages/kagent-core/tests/test_hitl_consts.py b/python/packages/kagent-core/tests/test_hitl_consts.py new file mode 100644 index 000000000..a5c4727d9 --- /dev/null +++ b/python/packages/kagent-core/tests/test_hitl_consts.py @@ -0,0 +1,29 @@ +"""Tests for HITL constants.""" + +from kagent.core.a2a import ( + KAGENT_HITL_DECISION_TYPE_APPROVE, + KAGENT_HITL_DECISION_TYPE_DENY, + KAGENT_HITL_DECISION_TYPE_KEY, + KAGENT_HITL_DECISION_TYPE_REJECT, + KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL, + KAGENT_HITL_RESUME_KEYWORDS_APPROVE, + KAGENT_HITL_RESUME_KEYWORDS_DENY, +) + + +def test_hitl_constants(): + """Test all HITL constants are defined with expected values.""" + # Interrupt types + assert KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL == "tool_approval" + + # Decision types + assert KAGENT_HITL_DECISION_TYPE_KEY == "decision_type" + assert KAGENT_HITL_DECISION_TYPE_APPROVE == "approve" + assert KAGENT_HITL_DECISION_TYPE_DENY == "deny" + assert KAGENT_HITL_DECISION_TYPE_REJECT == "reject" + + # Resume keywords + assert "approved" in KAGENT_HITL_RESUME_KEYWORDS_APPROVE + assert "proceed" in KAGENT_HITL_RESUME_KEYWORDS_APPROVE + assert "denied" in KAGENT_HITL_RESUME_KEYWORDS_DENY + assert "cancel" in KAGENT_HITL_RESUME_KEYWORDS_DENY diff --git a/python/packages/kagent-core/tests/test_hitl_handlers.py b/python/packages/kagent-core/tests/test_hitl_handlers.py new file mode 100644 index 000000000..2edf5d850 --- /dev/null +++ b/python/packages/kagent-core/tests/test_hitl_handlers.py @@ -0,0 +1,100 @@ +"""Tests for HITL handler functions.""" + +from unittest.mock import AsyncMock, Mock + +import pytest +from a2a.server.events.event_queue import EventQueue +from a2a.server.tasks import TaskStore +from a2a.types import TaskState, TaskStatusUpdateEvent + +from kagent.core.a2a import ( + KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL, + ToolApprovalRequest, + handle_tool_approval_interrupt, +) + + +@pytest.mark.asyncio +async def test_handle_tool_approval_interrupt(): + """Test tool approval interrupt handling with single and multiple actions.""" + # Setup mocks + event_queue = Mock(spec=EventQueue) + event_queue.enqueue_event = AsyncMock() + + task_store = Mock(spec=TaskStore) + task_store.wait_for_save = AsyncMock() + + # Test single action + action_requests = [ToolApprovalRequest(name="search", args={"query": "test"})] + + await handle_tool_approval_interrupt( + action_requests=action_requests, + task_id="task123", + context_id="ctx456", + event_queue=event_queue, + task_store=task_store, + app_name="test_app", + ) + + # Verify event was enqueued + assert event_queue.enqueue_event.call_count == 1 + event = event_queue.enqueue_event.call_args[0][0] + assert isinstance(event, TaskStatusUpdateEvent) + assert event.task_id == "task123" + assert event.context_id == "ctx456" + assert event.status.state == TaskState.input_required + assert event.final is False + assert event.metadata["interrupt_type"] == KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL + task_store.wait_for_save.assert_called_once_with("task123", timeout=5.0) + + # Reset mocks + event_queue.enqueue_event.reset_mock() + task_store.wait_for_save.reset_mock() + + # Test multiple actions + action_requests = [ + ToolApprovalRequest(name="tool1", args={"a": 1}), + ToolApprovalRequest(name="tool2", args={"b": 2}), + ] + + await handle_tool_approval_interrupt( + action_requests=action_requests, + task_id="task456", + context_id="ctx789", + event_queue=event_queue, + task_store=task_store, + ) + + # Verify event contains all actions + event = event_queue.enqueue_event.call_args[0][0] + message = event.status.message + assert len(message.parts) > 0 + + # Find DataPart with action_requests + data_parts = [p for p in message.parts if hasattr(p, "root") and hasattr(p.root, "data")] + assert len(data_parts) > 0 + + +@pytest.mark.asyncio +async def test_handle_tool_approval_interrupt_timeout(): + """Test that save timeout is handled gracefully.""" + event_queue = Mock(spec=EventQueue) + event_queue.enqueue_event = AsyncMock() + + task_store = Mock(spec=TaskStore) + # Simulate timeout + task_store.wait_for_save = AsyncMock(side_effect=TimeoutError()) + + action_requests = [ToolApprovalRequest(name="test", args={})] + + # Should not raise - timeout is caught and logged + await handle_tool_approval_interrupt( + action_requests=action_requests, + task_id="task123", + context_id="ctx456", + event_queue=event_queue, + task_store=task_store, + ) + + # Event should still be sent even if save times out + event_queue.enqueue_event.assert_called_once() diff --git a/python/packages/kagent-core/tests/test_hitl_utils.py b/python/packages/kagent-core/tests/test_hitl_utils.py new file mode 100644 index 000000000..2f94a0151 --- /dev/null +++ b/python/packages/kagent-core/tests/test_hitl_utils.py @@ -0,0 +1,145 @@ +"""Tests for HITL utility functions.""" + +import pytest +from a2a.types import DataPart, Message, Part, TaskState, TextPart + +from kagent.core.a2a import ( + KAGENT_HITL_DECISION_TYPE_APPROVE, + KAGENT_HITL_DECISION_TYPE_DENY, + KAGENT_HITL_DECISION_TYPE_KEY, + ToolApprovalRequest, + escape_markdown_backticks, + extract_decision_from_message, + format_tool_approval_text_parts, + is_input_required_task, +) + + +def test_escape_markdown_backticks(): + """Test backtick escaping for all cases.""" + assert escape_markdown_backticks("foo`bar") == "foo\\`bar" + assert escape_markdown_backticks("`code` and `more`") == "\\`code\\` and \\`more\\`" + assert escape_markdown_backticks("plain text") == "plain text" + assert escape_markdown_backticks("") == "" + + +def test_is_input_required_task(): + """Test is_input_required_task() for various states.""" + assert is_input_required_task(TaskState.input_required) is True + assert is_input_required_task(TaskState.working) is False + assert is_input_required_task(TaskState.completed) is False + assert is_input_required_task(None) is False + + +def test_extract_decision_datapart(): + """Test DataPart decision extraction (priority 1).""" + # Approve + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[Part(DataPart(data={KAGENT_HITL_DECISION_TYPE_KEY: KAGENT_HITL_DECISION_TYPE_APPROVE}))], + ) + assert extract_decision_from_message(message) == KAGENT_HITL_DECISION_TYPE_APPROVE + + # Deny + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[Part(DataPart(data={KAGENT_HITL_DECISION_TYPE_KEY: KAGENT_HITL_DECISION_TYPE_DENY}))], + ) + assert extract_decision_from_message(message) == KAGENT_HITL_DECISION_TYPE_DENY + + +def test_extract_decision_textpart(): + """Test TextPart keyword extraction (priority 2).""" + # Approve keyword + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[Part(TextPart(text="I have approved this action"))], + ) + assert extract_decision_from_message(message) == KAGENT_HITL_DECISION_TYPE_APPROVE + + # Deny keyword + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[Part(TextPart(text="Request denied, do not proceed"))], + ) + assert extract_decision_from_message(message) == KAGENT_HITL_DECISION_TYPE_DENY + + # Case insensitive + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[Part(TextPart(text="APPROVED"))], + ) + assert extract_decision_from_message(message) == KAGENT_HITL_DECISION_TYPE_APPROVE + + +def test_extract_decision_priority(): + """Test DataPart takes priority over TextPart.""" + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[ + Part(TextPart(text="approved")), # Would detect as approve + Part(DataPart(data={KAGENT_HITL_DECISION_TYPE_KEY: KAGENT_HITL_DECISION_TYPE_DENY})), # But deny wins + ], + ) + assert extract_decision_from_message(message) == KAGENT_HITL_DECISION_TYPE_DENY + + +def test_extract_decision_edge_cases(): + """Test edge cases: empty message, no parts, no decision.""" + # Empty message + assert extract_decision_from_message(None) is None + + # No parts + message = Message(role="user", message_id="test", task_id="task1", context_id="ctx1", parts=[]) + assert extract_decision_from_message(message) is None + + # No decision found + message = Message( + role="user", + message_id="test", + task_id="task1", + context_id="ctx1", + parts=[Part(TextPart(text="This is just a comment"))], + ) + assert extract_decision_from_message(message) is None + + +def test_format_tool_approval_text_parts(): + """Test formatting tool approval requests with all edge cases.""" + requests = [ + ToolApprovalRequest(name="search", args={"query": "test"}), + ToolApprovalRequest(name="run`code`", args={"cmd": "echo `test`"}), + ToolApprovalRequest(name="reset", args={}), + ] + parts = format_tool_approval_text_parts(requests) + + # Convert to text + text_content = "" + for p in parts: + if hasattr(p, "root") and hasattr(p.root, "text"): + text_content += p.root.text + + # Check structure and content + assert "Approval Required" in text_content + assert "search" in text_content + assert "reset" in text_content + # Check backticks are escaped + assert "\\`" in text_content diff --git a/python/packages/kagent-langgraph/src/kagent/langgraph/_checkpointer.py b/python/packages/kagent-langgraph/src/kagent/langgraph/_checkpointer.py index 85dd6d122..dabe1fc43 100644 --- a/python/packages/kagent-langgraph/src/kagent/langgraph/_checkpointer.py +++ b/python/packages/kagent-langgraph/src/kagent/langgraph/_checkpointer.py @@ -5,6 +5,7 @@ """ import base64 +import json import logging import random from collections.abc import AsyncIterator, Iterator, Sequence @@ -141,7 +142,8 @@ async def aput( thread_id, user_id, checkpoint_ns = self._extract_config_values(config) type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint) - serialized_metadata = self.jsonplus_serde.dumps(get_checkpoint_metadata(config, metadata)) + # Serialize metadata as JSON (simpler, no type needed) + serialized_metadata = json.dumps(get_checkpoint_metadata(config, metadata)).encode() # Prepare request data request_data = KAgentCheckpointPayload( thread_id=thread_id, @@ -233,7 +235,7 @@ def _convert_to_checkpoint_tuple( ), metadata=cast( CheckpointMetadata, - self.jsonplus_serde.loads(base64.b64decode(checkpoint_tuple.metadata.encode("ascii"))), + json.loads(base64.b64decode(checkpoint_tuple.metadata.encode("ascii"))), ), parent_config=( { diff --git a/python/packages/kagent-langgraph/src/kagent/langgraph/_converters.py b/python/packages/kagent-langgraph/src/kagent/langgraph/_converters.py index 32e9da3a2..97760c3ec 100644 --- a/python/packages/kagent-langgraph/src/kagent/langgraph/_converters.py +++ b/python/packages/kagent-langgraph/src/kagent/langgraph/_converters.py @@ -4,6 +4,7 @@ within the A2A (Agent-to-Agent) protocol, converting graph events to A2A events. """ +import hashlib import uuid from datetime import UTC, datetime from typing import Any @@ -31,20 +32,20 @@ get_kagent_metadata_key, ) - -def _get_event_metadata(langgraph_event: dict[str, Any]) -> dict[str, Any]: - """Get the metadata from a LangGraph event.""" - return { - "app_name": langgraph_event.get("app_name", ""), - "session_id": langgraph_event.get("session_id", ""), - } +from ._metadata_utils import get_rich_event_metadata async def _convert_langgraph_event_to_a2a( - langgraph_event: dict[str, Any], task_id: str, context_id: str, app_name: str + langgraph_event: dict[str, Any], + task_id: str, + context_id: str, + app_name: str, + sent_message_ids: set[str], ) -> list[TaskStatusUpdateEvent]: - """Convert a LangGraph event to A2A events.""" + """Convert a LangGraph event to A2A events. + Deduplicates messages using sent_message_ids to avoid replaying history. + """ a2a_events: list[TaskStatusUpdateEvent] = [] # LangGraph events have node names as keys, with 'messages' as values @@ -56,8 +57,17 @@ async def _convert_langgraph_event_to_a2a( if not isinstance(messages, list): continue - # Process each message in the event for message in messages: + # Deduplicate using content hash (message.id is often None) + msg_content = f"{type(message).__name__}:{message.content}" + if hasattr(message, "tool_calls") and message.tool_calls: + msg_content += f":tools:{len(message.tool_calls)}" + msg_id = hashlib.md5(msg_content.encode()).hexdigest() + + if msg_id in sent_message_ids: + continue + sent_message_ids.add(msg_id) + if isinstance(message, AIMessage): # Handle AI messages (assistant responses) a2a_message = Message(message_id=str(uuid.uuid4()), role=Role.agent, parts=[]) @@ -83,6 +93,11 @@ async def _convert_langgraph_event_to_a2a( ) ) ) + + # Only send message if it has parts (content or tool calls) + if not a2a_message.parts: + continue + a2a_events.append( TaskStatusUpdateEvent( task_id=task_id, @@ -93,10 +108,10 @@ async def _convert_langgraph_event_to_a2a( ), context_id=context_id, final=False, - metadata={ - "app_name": app_name, - "session_id": context_id, - }, + metadata=get_rich_event_metadata( + app_name=app_name, + session_id=context_id, + ), ) ) @@ -132,34 +147,14 @@ async def _convert_langgraph_event_to_a2a( ), context_id=context_id, final=False, - metadata={ - "app_name": app_name, - "session_id": context_id, - }, + metadata=get_rich_event_metadata( + app_name=app_name, + session_id=context_id, + ), ) ) elif isinstance(message, HumanMessage): - # Handle human messages (user input) - usually for context - if message.content and isinstance(message.content, str) and message.content.strip(): - a2a_events.append( - TaskStatusUpdateEvent( - task_id=task_id, - status=TaskStatus( - state=TaskState.working, - timestamp=datetime.now(UTC).isoformat(), - message=Message( - message_id=str(uuid.uuid4()), - role=Role.agent, - parts=[Part(TextPart(text=message.content))], - ), - ), - context_id=context_id, - final=False, - metadata={ - "app_name": app_name, - "session_id": context_id, - }, - ) - ) + # Skip - user input is already known by caller + pass return a2a_events diff --git a/python/packages/kagent-langgraph/src/kagent/langgraph/_error_mappings.py b/python/packages/kagent-langgraph/src/kagent/langgraph/_error_mappings.py new file mode 100644 index 000000000..f6a85a204 --- /dev/null +++ b/python/packages/kagent-langgraph/src/kagent/langgraph/_error_mappings.py @@ -0,0 +1,48 @@ +"""Error code to user-friendly message mappings for LangGraph events.""" + +# Map common exception types to user-friendly messages +ERROR_TYPE_MESSAGES = { + "TimeoutError": "Request timed out. Please try again or simplify your request.", + "ValidationError": "Invalid input provided. Please check your request format.", + "RateLimitError": "Rate limit exceeded. Please wait a moment and try again.", + "AuthenticationError": "Authentication failed. Please check your credentials.", + "PermissionError": "Permission denied. You don't have access to this resource.", + "ValueError": "Invalid value provided. Please check your input.", + "KeyError": "Required field missing. Please provide all required information.", + "ConnectionError": "Connection failed. Please check your network and try again.", + "HTTPError": "HTTP request failed. The external service may be unavailable.", + "APIError": "API request failed. Please try again later.", + "BadRequestError": "Invalid request. Please check your input and try again.", + "NotFoundError": "Resource not found. Please check the identifier and try again.", + "RuntimeError": "An unexpected error occurred during execution.", +} + +DEFAULT_ERROR_MESSAGE = "An error occurred during processing. Please try again or rephrase your request." + + +def get_user_friendly_error_message(exception: Exception) -> str: + """Get a user-friendly error message for the given exception. + + Args: + exception: The exception that was raised + + Returns: + A user-friendly error message string + """ + error_type = type(exception).__name__ + return ERROR_TYPE_MESSAGES.get(error_type, DEFAULT_ERROR_MESSAGE) + + +def get_error_metadata(exception: Exception) -> dict[str, str]: + """Get metadata dict with error details. + + Args: + exception: The exception that was raised + + Returns: + Dict with error_type and error_detail + """ + return { + "error_type": type(exception).__name__, + "error_detail": str(exception), + } diff --git a/python/packages/kagent-langgraph/src/kagent/langgraph/_executor.py b/python/packages/kagent-langgraph/src/kagent/langgraph/_executor.py index ad5acacfd..e52864818 100644 --- a/python/packages/kagent-langgraph/src/kagent/langgraph/_executor.py +++ b/python/packages/kagent-langgraph/src/kagent/langgraph/_executor.py @@ -13,6 +13,7 @@ from a2a.server.agent_execution import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue +from a2a.server.tasks import TaskStore from a2a.types import ( Artifact, Message, @@ -27,10 +28,20 @@ from langchain_core.runnables import RunnableConfig from pydantic import BaseModel -from kagent.core.a2a import TaskResultAggregator +from kagent.core.a2a import ( + KAGENT_HITL_DECISION_TYPE_DENY, + TaskResultAggregator, + ToolApprovalRequest, + extract_decision_from_message, + get_kagent_metadata_key, + handle_tool_approval_interrupt, + is_input_required_task, +) from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Command from ._converters import _convert_langgraph_event_to_a2a +from ._error_mappings import get_error_metadata, get_user_friendly_error_message logger = logging.getLogger(__name__) @@ -110,20 +121,42 @@ async def _stream_graph_events( """Stream LangGraph events and convert them to A2A events.""" task_result_aggregator = TaskResultAggregator() + # Track final state for interrupt detection + final_state: dict[str, Any] | None = None + + # Track message IDs we've already sent to avoid duplicates + sent_message_ids: set[str] = set() + # Stream events from the graph async for event in graph.astream( input_data, config, - stream_mode="updates", # Stream the individual events + stream_mode="updates", ): + # Store final state + final_state = event + # Convert LangGraph events to A2A events a2a_events = await _convert_langgraph_event_to_a2a( - event, context.task_id, context.context_id, self.app_name + event, context.task_id, context.context_id, self.app_name, sent_message_ids ) for a2a_event in a2a_events: task_result_aggregator.process_event(a2a_event) await event_queue.enqueue_event(a2a_event) + # Check for interrupts after streaming completes + if final_state and final_state.get("__interrupt__"): + interrupt_data = final_state["__interrupt__"] + await self._handle_interrupt( + interrupt_data=interrupt_data, + task_id=context.task_id, + context_id=context.context_id, + event_queue=event_queue, + task_store=context.task_store, + ) + # Interrupt detected - input_required event already sent, so return early + return + # Final artifacts are already sent through individual event processing # publish the task result event - this is final @@ -171,12 +204,181 @@ async def _stream_graph_events( ) ) + async def _handle_interrupt( + self, + interrupt_data: list[Any], + task_id: str, + context_id: str, + event_queue: EventQueue, + task_store: TaskStore, + ) -> None: + """Handle interrupt from LangGraph and convert to A2A input_required event. + + This is the LangGraph-specific adapter that extracts interrupt data from + LangGraph's format and delegates to the generic handler in kagent-core. + """ + # Extract interrupt details from LangGraph format + if not interrupt_data: + logger.warning("Empty interrupt data received") + return + + # Safely extract interrupt value (LangGraph-specific format) + first_item = interrupt_data[0] + if hasattr(first_item, "value"): + interrupt_value = first_item.value + elif isinstance(first_item, dict): + interrupt_value = first_item + else: + logger.error(f"Unexpected interrupt data type: {type(first_item)}") + return + + # Extract LangGraph-specific fields + action_requests_raw = interrupt_value.get("action_requests", []) + review_configs = interrupt_value.get("review_configs", []) + + # Convert to generic ToolApprovalRequest format + action_requests = [ + ToolApprovalRequest( + name=action.get("name", "unknown"), + args=action.get("args", {}), + id=action.get("id"), + ) + for action in action_requests_raw + ] + + # Delegate to generic handler in kagent-core + await handle_tool_approval_interrupt( + action_requests=action_requests, + task_id=task_id, + context_id=context_id, + event_queue=event_queue, + task_store=task_store, + app_name=self.app_name, + review_configs=review_configs, + ) + @override async def cancel(self, context: RequestContext, event_queue: EventQueue): """Cancel the execution.""" # TODO: Implement proper cancellation logic if needed raise NotImplementedError("Cancellation is not implemented") + def _is_resume_command(self, context: RequestContext) -> bool: + """Check if message is a resume command for an interrupted task. + + Uses generic utilities from kagent-core for decision extraction. + """ + # Must have an existing task in input_required state to resume + if not context.current_task: + return False + + if not is_input_required_task(context.current_task.status.state): + return False + + # Check if message contains a decision + decision = extract_decision_from_message(context.message) + return decision is not None + + async def _handle_resume( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Resume graph execution after interrupt with user decision.""" + # Extract decision from message using core utility + decision_type = extract_decision_from_message(context.message) + + if not decision_type: + # Security: Default to deny if decision cannot be determined + logger.warning(f"Could not determine decision from message for task {context.task_id}, defaulting to deny") + decision_type = KAGENT_HITL_DECISION_TYPE_DENY + + # Get thread_id from existing task metadata (critical for resume!) + thread_id = None + if context.current_task and context.current_task.metadata: + thread_id = context.current_task.metadata.get("thread_id") + + if not thread_id: + # Fallback to computing from context (same as initial) + thread_id = getattr(context, "session_id", None) or context.context_id + + logger.info( + f"Resuming after interrupt - task_id={context.task_id}, thread_id={thread_id}, decision={decision_type}" + ) + + # Create resume input + resume_input = Command(resume={"decisions": [{"type": decision_type}]}) + + # Create graph config with explicit thread_id + config = { + "configurable": { + "thread_id": thread_id, # Use thread from interrupted task! + "app_name": self.app_name, + }, + "project_name": self.app_name, + "run_name": "kagent-langgraph-resume", + "tags": [ + "kagent", + "langgraph", + "resume", + f"app:{self.app_name}", + f"task:{context.task_id}", + f"context:{context.context_id}", + f"thread:{thread_id}", + ], + "metadata": { + "kagent_app_name": self.app_name, + "a2a_context_id": context.context_id, + "a2a_task_id": context.task_id, + "thread_id": thread_id, + "resume": True, + }, + } + + # Send working status + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(UTC).isoformat(), + ), + context_id=context.context_id, + final=False, + ) + ) + + # Resume graph execution + try: + await asyncio.wait_for( + self._stream_graph_events( + self._graph, + resume_input, # Pass Command instead of messages + config, + context, + event_queue, + ), + timeout=self._config.execution_timeout, + ) + except Exception as e: + logger.error(f"Error during resume: {e}", exc_info=True) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(UTC).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[Part(TextPart(text=f"Resume failed: {str(e)}"))], + ), + ), + context_id=context.context_id, + final=True, + ) + ) + @override async def execute( self, @@ -187,6 +389,13 @@ async def execute( if not context.message: raise ValueError("A2A request must have a message") + # Check if this is a resume command (check before current_task check) + # Resume commands can come as new messages to continue interrupted tasks + if self._is_resume_command(context): + logger.info(f"Resuming task {context.task_id} after interrupt") + await self._handle_resume(context, event_queue) + return + # Send task submitted event for new tasks if not context.current_task: await event_queue.enqueue_event( @@ -202,6 +411,9 @@ async def execute( ) ) + # Calculate and store thread_id for potential resume + thread_id = getattr(context, "session_id", None) or context.context_id + # Send working status await event_queue.enqueue_event( TaskStatusUpdateEvent( @@ -215,6 +427,7 @@ async def execute( metadata={ "app_name": self.app_name, "session_id": getattr(context, "session_id", context.context_id), + "thread_id": thread_id, # Store for resume! }, ) ) @@ -254,6 +467,11 @@ async def execute( ) except Exception as e: logger.error(f"Error during LangGraph execution: {e}", exc_info=True) + + # Get user-friendly message + user_message = get_user_friendly_error_message(e) + error_meta = get_error_metadata(e) + await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -263,10 +481,18 @@ async def execute( message=Message( message_id=str(uuid.uuid4()), role=Role.agent, - parts=[Part(TextPart(text=str(e)))], + parts=[Part(TextPart(text=user_message))], + metadata={ + get_kagent_metadata_key("error_type"): error_meta["error_type"], + get_kagent_metadata_key("error_detail"): error_meta["error_detail"], + }, ), ), context_id=context.context_id, final=True, + metadata={ + get_kagent_metadata_key("error_type"): error_meta["error_type"], + get_kagent_metadata_key("error_detail"): error_meta["error_detail"], + }, ) ) diff --git a/python/packages/kagent-langgraph/src/kagent/langgraph/_metadata_utils.py b/python/packages/kagent-langgraph/src/kagent/langgraph/_metadata_utils.py new file mode 100644 index 000000000..0adfb51d3 --- /dev/null +++ b/python/packages/kagent-langgraph/src/kagent/langgraph/_metadata_utils.py @@ -0,0 +1,65 @@ +"""Metadata utilities for rich event metadata.""" + +import logging +from typing import Any + +from kagent.core.a2a import get_kagent_metadata_key + +logger = logging.getLogger(__name__) + + +def serialize_metadata_value(value: Any) -> str: + """Safely serializes metadata values to string format. + + Args: + value: The value to serialize + + Returns: + String representation of the value + """ + if hasattr(value, "model_dump"): + try: + return str(value.model_dump(exclude_none=True, by_alias=True)) + except Exception as e: + logger.warning(f"Failed to serialize metadata value: {e}") + return str(value) + return str(value) + + +def get_rich_event_metadata( + app_name: str, + session_id: str, + user_id: str | None = None, + invocation_id: str | None = None, + extra_fields: dict[str, Any] | None = None, +) -> dict[str, str]: + """Get rich metadata for A2A events. + + Args: + app_name: Application name + session_id: Session/context ID + user_id: Optional user identifier + invocation_id: Optional invocation/request identifier + extra_fields: Optional additional metadata fields + + Returns: + Dict with namespaced metadata keys + """ + metadata = { + get_kagent_metadata_key("app_name"): app_name, + get_kagent_metadata_key("session_id"): session_id, + } + + # Add optional core fields + if user_id: + metadata[get_kagent_metadata_key("user_id")] = user_id + if invocation_id: + metadata[get_kagent_metadata_key("invocation_id")] = invocation_id + + # Add extra fields if provided + if extra_fields: + for key, value in extra_fields.items(): + if value is not None: + metadata[get_kagent_metadata_key(key)] = serialize_metadata_value(value) + + return metadata diff --git a/python/uv.lock b/python/uv.lock index ca72911f6..12a342ace 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.13" resolution-markers = [ "python_full_version >= '3.14'", @@ -28,7 +28,7 @@ dev = [ [[package]] name = "a2a-sdk" -version = "0.3.3" +version = "0.3.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -37,9 +37,9 @@ dependencies = [ { name = "protobuf" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f0/33/25ddb456829784575b5c693699162f7fa930d90a82e16fefaa53d8a02154/a2a_sdk-0.3.3.tar.gz", hash = "sha256:d32426a819d3305116d24e938787b0028c4240df862e7fc2c7bf039def0d60e2", size = 219695, upload-time = "2025-08-25T18:32:09.8Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/0b/80671e784f61b55ac4c340d125d121ba91eba58ad7ba0f03b53b3831cd32/a2a_sdk-0.3.9.tar.gz", hash = "sha256:1dff7b5b1cab0b221519d0faed50176e200a1a87a8de8b64308d876505cc7c77", size = 224528, upload-time = "2025-10-15T17:35:28.299Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/cb/03315694ea2f9536c796767da0a1dd6d0b7c48df4badbab19c2c9c570741/a2a_sdk-0.3.3-py3-none-any.whl", hash = "sha256:d433c7f13a7a4b426e3aef798f9ef6eff0d3aea68c9a595634af865111b5c7c2", size = 135159, upload-time = "2025-08-25T18:32:08.18Z" }, + { url = "https://files.pythonhosted.org/packages/34/ee/53b2da6d2768b136f996b8c6ab00ebcc44852f9a33816a64deaca6b279fe/a2a_sdk-0.3.9-py3-none-any.whl", hash = "sha256:7ed03a915bae98def46ea0313786da0a7a488346c3dc8af88407bb0b2a763926", size = 139027, upload-time = "2025-10-15T17:35:26.628Z" }, ] [package.optional-dependencies] @@ -1893,7 +1893,7 @@ name = "kagent-core" version = "0.1.0" source = { editable = "packages/kagent-core" } dependencies = [ - { name = "a2a-sdk" }, + { name = "a2a-sdk", extra = ["http-server"] }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp-proto-grpc" }, { name = "opentelemetry-instrumentation-anthropic" }, @@ -1906,7 +1906,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "a2a-sdk", specifier = ">=0.2.16" }, + { name = "a2a-sdk", extras = ["http-server"], specifier = ">=0.3.9" }, { name = "opentelemetry-api", specifier = ">=1.36.0" }, { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = ">=1.36.0" }, { name = "opentelemetry-instrumentation-anthropic", specifier = ">=0.44.0" },