diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 4898bc0b..cc84cf54 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -9,6 +9,7 @@ CLIConnectionError, CLIJSONDecodeError, CLINotFoundError, + HookCallbackError, ProcessError, ) from ._internal.transport import Transport @@ -360,6 +361,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "ClaudeSDKError", "CLIConnectionError", "CLINotFoundError", + "HookCallbackError", "ProcessError", "CLIJSONDecodeError", ] diff --git a/src/claude_agent_sdk/_errors.py b/src/claude_agent_sdk/_errors.py index c86bf235..269d6728 100644 --- a/src/claude_agent_sdk/_errors.py +++ b/src/claude_agent_sdk/_errors.py @@ -54,3 +54,13 @@ class MessageParseError(ClaudeSDKError): def __init__(self, message: str, data: dict[str, Any] | None = None): self.data = data super().__init__(message) + + +class HookCallbackError(ClaudeSDKError): + """Raised when a hook callback returns an invalid value.""" + + def __init__(self, message: str, callback_id: str | None = None): + self.callback_id = callback_id + if callback_id: + message = f"{message} (callback_id: {callback_id})" + super().__init__(message) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 6bf5a73c..f3d765f7 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -14,6 +14,7 @@ ListToolsRequest, ) +from .._errors import HookCallbackError from ..types import ( PermissionResultAllow, PermissionResultDeny, @@ -290,6 +291,21 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: request_data.get("tool_use_id"), {"signal": None}, # TODO: Add abort signal support ) + + # Validate hook callback return type + if hook_output is None: + raise HookCallbackError( + "Hook callback returned None. Expected a dict " + "(HookJSONOutput). Did you forget to return a value?", + callback_id=callback_id, + ) + if not isinstance(hook_output, dict): + raise HookCallbackError( + f"Hook callback returned {type(hook_output).__name__}, " + f"expected dict (HookJSONOutput)", + callback_id=callback_id, + ) + # Convert Python-safe field names (async_, continue_) to CLI-expected names (async, continue) response_data = _convert_hook_output_for_cli(hook_output) diff --git a/tests/test_tool_callbacks.py b/tests/test_tool_callbacks.py index 8ace3c8d..a6e5b089 100644 --- a/tests/test_tool_callbacks.py +++ b/tests/test_tool_callbacks.py @@ -486,3 +486,192 @@ async def my_hook( assert "tool_use_start" in options.hooks assert len(options.hooks["tool_use_start"]) == 1 assert options.hooks["tool_use_start"][0].hooks[0] == my_hook + + +class TestHookCallbackReturnValidation: + """Test validation of hook callback return values.""" + + @pytest.mark.asyncio + async def test_hook_callback_returns_none_raises_error(self): + """Test that returning None from a hook callback raises HookCallbackError.""" + + async def bad_hook( + input_data: HookInput, tool_use_id: str | None, context: HookContext + ): + # Oops, forgot to return a value + pass + + transport = MockTransport() + query = Query( + transport=transport, is_streaming_mode=True, can_use_tool=None, hooks={} + ) + + callback_id = "test_none_hook" + query.hook_callbacks[callback_id] = bad_hook + + request = { + "type": "control_request", + "request_id": "test-none-return", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": None, + }, + } + + await query._handle_control_request(request) + + # Check that an error response was sent + assert len(transport.written_messages) > 0 + last_response = json.loads(transport.written_messages[-1]) + assert last_response["response"]["subtype"] == "error" + assert "returned None" in last_response["response"]["error"] + assert "callback_id: test_none_hook" in last_response["response"]["error"] + + @pytest.mark.asyncio + async def test_hook_callback_returns_string_raises_error(self): + """Test that returning a string from a hook callback raises HookCallbackError.""" + + async def bad_hook( + input_data: HookInput, tool_use_id: str | None, context: HookContext + ): + return "this is not a dict" + + transport = MockTransport() + query = Query( + transport=transport, is_streaming_mode=True, can_use_tool=None, hooks={} + ) + + callback_id = "test_string_hook" + query.hook_callbacks[callback_id] = bad_hook + + request = { + "type": "control_request", + "request_id": "test-string-return", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": None, + }, + } + + await query._handle_control_request(request) + + # Check that an error response was sent + assert len(transport.written_messages) > 0 + last_response = json.loads(transport.written_messages[-1]) + assert last_response["response"]["subtype"] == "error" + assert "returned str" in last_response["response"]["error"] + assert "expected dict" in last_response["response"]["error"] + + @pytest.mark.asyncio + async def test_hook_callback_returns_list_raises_error(self): + """Test that returning a list from a hook callback raises HookCallbackError.""" + + async def bad_hook( + input_data: HookInput, tool_use_id: str | None, context: HookContext + ): + return ["not", "a", "dict"] + + transport = MockTransport() + query = Query( + transport=transport, is_streaming_mode=True, can_use_tool=None, hooks={} + ) + + callback_id = "test_list_hook" + query.hook_callbacks[callback_id] = bad_hook + + request = { + "type": "control_request", + "request_id": "test-list-return", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": None, + }, + } + + await query._handle_control_request(request) + + # Check that an error response was sent + assert len(transport.written_messages) > 0 + last_response = json.loads(transport.written_messages[-1]) + assert last_response["response"]["subtype"] == "error" + assert "returned list" in last_response["response"]["error"] + + @pytest.mark.asyncio + async def test_hook_callback_valid_dict_succeeds(self): + """Test that returning a valid dict from a hook callback succeeds.""" + + async def good_hook( + input_data: HookInput, tool_use_id: str | None, context: HookContext + ) -> HookJSONOutput: + return {"continue_": True} + + transport = MockTransport() + query = Query( + transport=transport, is_streaming_mode=True, can_use_tool=None, hooks={} + ) + + callback_id = "test_valid_hook" + query.hook_callbacks[callback_id] = good_hook + + request = { + "type": "control_request", + "request_id": "test-valid-return", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": None, + }, + } + + await query._handle_control_request(request) + + # Check that a success response was sent (not an error) + assert len(transport.written_messages) > 0 + last_response = json.loads(transport.written_messages[-1]) + assert last_response["response"]["subtype"] == "success" + # The converted output should have "continue" (not "continue_") + assert last_response["response"]["response"]["continue"] is True + + @pytest.mark.asyncio + async def test_hook_callback_empty_dict_succeeds(self): + """Test that returning an empty dict from a hook callback succeeds.""" + + async def minimal_hook( + input_data: HookInput, tool_use_id: str | None, context: HookContext + ) -> HookJSONOutput: + return {} + + transport = MockTransport() + query = Query( + transport=transport, is_streaming_mode=True, can_use_tool=None, hooks={} + ) + + callback_id = "test_empty_hook" + query.hook_callbacks[callback_id] = minimal_hook + + request = { + "type": "control_request", + "request_id": "test-empty-return", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": None, + }, + } + + await query._handle_control_request(request) + + # Check that a success response was sent (not an error) + assert len(transport.written_messages) > 0 + last_response = json.loads(transport.written_messages[-1]) + assert last_response["response"]["subtype"] == "success" + # Empty dict is valid - response should contain an empty dict + assert last_response["response"]["response"] == {}