diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 3f55f718..abef0b73 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -19,8 +19,10 @@ from aiohttp import ClientSession from deprecated import deprecated -from .protocol import ManifestSchema, ToolSchema +from .itransport import ITransport +from .protocol import ToolSchema from .tool import ToolboxTool +from .toolbox_transport import ToolboxTransport from .utils import identify_auth_requirements, resolve_value @@ -33,9 +35,7 @@ class ToolboxClient: is not provided. """ - __base_url: str - __session: ClientSession - __manage_session: bool + __transport: ITransport def __init__( self, @@ -56,15 +56,8 @@ def __init__( should typically be managed externally. client_headers: Headers to include in each request sent through this client. """ - self.__base_url = url - - # If no aiohttp.ClientSession is provided, make our own - self.__manage_session = False - if session is None: - self.__manage_session = True - session = ClientSession() - self.__session = session + self.__transport = ToolboxTransport(url, session) self.__client_headers = client_headers if client_headers is not None else {} def __parse_tool( @@ -103,8 +96,7 @@ def __parse_tool( ) tool = ToolboxTool( - session=self.__session, - base_url=self.__base_url, + transport=self.__transport, name=name, description=schema.description, # create a read-only values to prevent mutation @@ -149,8 +141,7 @@ async def close(self): If the session was provided externally during initialization, the caller is responsible for its lifecycle. """ - if self.__manage_session and not self.__session.closed: - await self.__session.close() + await self.__transport.close() async def load_tool( self, @@ -191,16 +182,7 @@ async def load_tool( for name, val in self.__client_headers.items() } - # request the definition of the tool from the server - url = f"{self.__base_url}/api/tool/{name}" - async with self.__session.get(url, headers=resolved_headers) as response: - if not response.ok: - error_text = await response.text() - raise RuntimeError( - f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}" - ) - json = await response.json() - manifest: ManifestSchema = ManifestSchema(**json) + manifest = await self.__transport.tool_get(name, resolved_headers) # parse the provided definition to a tool if name not in manifest.tools: @@ -274,16 +256,8 @@ async def load_toolset( header_name: await resolve_value(original_headers[header_name]) for header_name in original_headers } - # Request the definition of the toolset from the server - url = f"{self.__base_url}/api/toolset/{name or ''}" - async with self.__session.get(url, headers=resolved_headers) as response: - if not response.ok: - error_text = await response.text() - raise RuntimeError( - f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}" - ) - json = await response.json() - manifest: ManifestSchema = ManifestSchema(**json) + + manifest = await self.__transport.tools_list(name, resolved_headers) tools: list[ToolboxTool] = [] overall_used_auth_keys: set[str] = set() diff --git a/packages/toolbox-core/src/toolbox_core/itransport.py b/packages/toolbox-core/src/toolbox_core/itransport.py new file mode 100644 index 00000000..0b38beb5 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/itransport.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Mapping, Optional + +from .protocol import ManifestSchema + + +class ITransport(ABC): + """Defines the contract for a 'smart' transport that handles both + protocol formatting and network communication. + """ + + @property + @abstractmethod + def base_url(self) -> str: + """The base URL for the transport.""" + pass + + @abstractmethod + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server.""" + pass + + @abstractmethod + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server.""" + pass + + @abstractmethod + async def tool_invoke( + self, tool_name: str, arguments: dict, headers: Mapping[str, str] + ) -> str: + """Invokes a specific tool on the server.""" + pass + + @abstractmethod + async def close(self): + """Closes any underlying connections.""" + pass diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 8d8e7825..0c72c39b 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -20,8 +20,7 @@ from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union from warnings import warn -from aiohttp import ClientSession - +from .itransport import ITransport from .protocol import ParameterSchema from .utils import ( create_func_docstring, @@ -46,8 +45,7 @@ class ToolboxTool: def __init__( self, - session: ClientSession, - base_url: str, + transport: ITransport, name: str, description: str, params: Sequence[ParameterSchema], @@ -68,8 +66,7 @@ def __init__( Toolbox server. Args: - session: The `aiohttp.ClientSession` used for making API requests. - base_url: The base URL of the Toolbox server API. + transport: The transport used for making API requests. name: The name of the remote tool. description: The description of the remote tool. params: The args of the tool. @@ -84,9 +81,7 @@ def __init__( client_headers: Client specific headers bound to the tool. """ # used to invoke the toolbox API - self.__session: ClientSession = session - self.__base_url: str = base_url - self.__url = f"{base_url}/api/tool/{name}/invoke" + self.__transport = transport self.__description = description self.__params = params self.__pydantic_model = params_to_pydantic_model(name, self.__params) @@ -120,17 +115,6 @@ def __init__( # map of client headers to their value/callable/coroutine self.__client_headers = client_headers - # ID tokens contain sensitive user information (claims). Transmitting - # these over HTTP exposes the data to interception and unauthorized - # access. Always use HTTPS to ensure secure communication and protect - # user privacy. - if ( - required_authn_params or required_authz_tokens or client_headers - ) and not self.__url.startswith("https://"): - warn( - "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." - ) - @property def _name(self) -> str: return self.__name__ @@ -171,8 +155,7 @@ def _client_headers( def __copy( self, - session: Optional[ClientSession] = None, - base_url: Optional[str] = None, + transport: Optional[ITransport] = None, name: Optional[str] = None, description: Optional[str] = None, params: Optional[Sequence[ParameterSchema]] = None, @@ -192,8 +175,7 @@ def __copy( Creates a copy of the ToolboxTool, overriding specific fields. Args: - session: The `aiohttp.ClientSession` used for making API requests. - base_url: The base URL of the Toolbox server API. + transport: The transport used for making API requests. name: The name of the remote tool. description: The description of the remote tool. params: The args of the tool. @@ -209,8 +191,7 @@ def __copy( """ check = lambda val, default: val if val is not None else default return ToolboxTool( - session=check(session, self.__session), - base_url=check(base_url, self.__base_url), + transport=check(transport, self.__transport), name=check(name, self.__name__), description=check(description, self.__description), params=check(params, self.__params), @@ -291,16 +272,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: token_getter ) - async with self.__session.post( - self.__url, - json=payload, - headers=headers, - ) as resp: - body = await resp.json() - if not resp.ok: - err = body.get("error", f"unexpected status from server: {resp.status}") - raise Exception(err) - return body.get("result", body) + return await self.__transport.tool_invoke( + self.__name__, + payload, + headers, + ) def add_auth_token_getters( self, diff --git a/packages/toolbox-core/src/toolbox_core/toolbox_transport.py b/packages/toolbox-core/src/toolbox_core/toolbox_transport.py new file mode 100644 index 00000000..0f1e7e40 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/toolbox_transport.py @@ -0,0 +1,95 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Mapping, Optional +from warnings import warn + +from aiohttp import ClientSession + +from .itransport import ITransport +from .protocol import ManifestSchema + + +class ToolboxTransport(ITransport): + """Transport for the native Toolbox protocol.""" + + def __init__(self, base_url: str, session: Optional[ClientSession]): + self.__base_url = base_url + + # If no aiohttp.ClientSession is provided, make our own + self.__manage_session = False + if session is not None: + self.__session = session + else: + self.__manage_session = True + self.__session = ClientSession() + + @property + def base_url(self) -> str: + """The base URL for the transport.""" + return self.__base_url + + async def __get_manifest( + self, url: str, headers: Optional[Mapping[str, str]] + ) -> ManifestSchema: + """Helper method to perform GET requests and parse the ManifestSchema.""" + async with self.__session.get(url, headers=headers) as response: + if not response.ok: + error_text = await response.text() + raise RuntimeError( + f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}" + ) + json = await response.json() + return ManifestSchema(**json) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + url = f"{self.__base_url}/api/tool/{tool_name}" + return await self.__get_manifest(url, headers) + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + url = f"{self.__base_url}/api/toolset/{toolset_name or ''}" + return await self.__get_manifest(url, headers) + + async def tool_invoke( + self, tool_name: str, arguments: dict, headers: Mapping[str, str] + ) -> str: + # ID tokens contain sensitive user information (claims). Transmitting + # these over HTTP exposes the data to interception and unauthorized + # access. Always use HTTPS to ensure secure communication and protect + # user privacy. + if self.base_url.startswith("http://") and headers: + warn( + "Sending data token over HTTP. User data may be exposed. Use HTTPS for secure communication." + ) + url = f"{self.__base_url}/api/tool/{tool_name}/invoke" + async with self.__session.post( + url, + json=arguments, + headers=headers, + ) as resp: + body = await resp.json() + if not resp.ok: + err = body.get("error", f"unexpected status from server: {resp.status}") + raise Exception(err) + return body.get("result") + + async def close(self): + if self.__manage_session and not self.__session.closed: + await self.__session.close() diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py index bd27d3f2..7708c938 100644 --- a/packages/toolbox-core/tests/test_tool.py +++ b/packages/toolbox-core/tests/test_tool.py @@ -26,7 +26,9 @@ from pydantic import ValidationError from toolbox_core.protocol import ParameterSchema -from toolbox_core.tool import ToolboxTool, create_func_docstring, resolve_value +from toolbox_core.tool import ToolboxTool +from toolbox_core.toolbox_transport import ToolboxTransport +from toolbox_core.utils import create_func_docstring, resolve_value TEST_BASE_URL = "http://toolbox.example.com" HTTPS_BASE_URL = "https://toolbox.example.com" @@ -110,9 +112,9 @@ def toolbox_tool( sample_tool_description: str, ) -> ToolboxTool: """Fixture for a ToolboxTool instance with common test setup.""" + transport = ToolboxTransport(TEST_BASE_URL, http_session) return ToolboxTool( - session=http_session, - base_url=TEST_BASE_URL, + transport=transport, name=TEST_TOOL_NAME, description=sample_tool_description, params=sample_tool_params, @@ -229,10 +231,10 @@ async def test_tool_creation_callable_and_run( with aioresponses() as m: m.post(invoke_url, status=200, payload=mock_server_response_body) + transport = ToolboxTransport(base_url, http_session) tool_instance = ToolboxTool( - session=http_session, - base_url=base_url, + transport=transport, name=tool_name, description=sample_tool_description, params=sample_tool_params, @@ -275,10 +277,10 @@ async def test_tool_run_with_pydantic_validation_error( with aioresponses() as m: m.post(invoke_url, status=200, payload={"result": "Should not be called"}) + transport = ToolboxTransport(base_url, http_session) tool_instance = ToolboxTool( - session=http_session, - base_url=base_url, + transport=transport, name=tool_name, description=sample_tool_description, params=sample_tool_params, @@ -366,10 +368,10 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti """Tests basic tool initialization without headers or auth.""" with catch_warnings(record=True) as record: simplefilter("always") + transport = ToolboxTransport(HTTPS_BASE_URL, http_session) tool_instance = ToolboxTool( - session=http_session, - base_url=HTTPS_BASE_URL, + transport=transport, name=TEST_TOOL_NAME, description=sample_tool_description, params=sample_tool_params, @@ -396,9 +398,9 @@ def test_tool_init_with_client_headers( http_session, sample_tool_params, sample_tool_description, static_client_header ): """Tests tool initialization *with* client headers.""" + transport = ToolboxTransport(HTTPS_BASE_URL, http_session) tool_instance = ToolboxTool( - session=http_session, - base_url=HTTPS_BASE_URL, + transport=transport, name=TEST_TOOL_NAME, description=sample_tool_description, params=sample_tool_params, @@ -420,9 +422,9 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header( Tests ValueError when add_auth_token_getters introduces an auth service whose token name conflicts with an existing client header. """ + transport = ToolboxTransport(HTTPS_BASE_URL, http_session) tool_instance = ToolboxTool( - session=http_session, - base_url=HTTPS_BASE_URL, + transport=transport, name="tool_with_client_header", description=sample_tool_description, params=sample_tool_params, @@ -456,40 +458,30 @@ async def test_auth_token_overrides_client_header( Tests that an auth token getter's value overrides a client header with the same name during the actual tool call. """ - - auth_service_name = "test-auth" - auth_header_key = f"{auth_service_name}_token" - auth_token_value = "value-from-auth-getter-123" - auth_getters = {auth_service_name: lambda: auth_token_value} - + transport = ToolboxTransport(HTTPS_BASE_URL, http_session) + tool_instance = ToolboxTool( + transport=transport, + name=TEST_TOOL_NAME, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + required_authz_tokens=[], + auth_service_token_getters={"test-auth": lambda: "value-from-auth-getter-123"}, + bound_params={}, + client_headers={ + "test-auth_token": "value-from-client", + "X-Another-Header": "another-value", + }, + ) tool_name = TEST_TOOL_NAME base_url = HTTPS_BASE_URL invoke_url = f"{base_url}/api/tool/{tool_name}/invoke" - client_headers = { - auth_header_key: "value-from-client", - "X-Another-Header": "another-value", - } - input_args = {"message": "test", "count": 1} mock_server_response = {"result": "Success"} with aioresponses() as m: m.post(invoke_url, status=200, payload=mock_server_response) - - tool_instance = ToolboxTool( - session=http_session, - base_url=base_url, - name=tool_name, - description=sample_tool_description, - params=sample_tool_params, - auth_service_token_getters=auth_getters, - client_headers=client_headers, - required_authn_params={}, - required_authz_tokens=[], - bound_params={}, - ) - # Call the tool await tool_instance(**input_args) @@ -498,7 +490,7 @@ async def test_auth_token_overrides_client_header( method="POST", json=input_args, headers={ - auth_header_key: auth_token_value, + "test-auth_token": "value-from-auth-getter-123", "X-Another-Header": "another-value", }, ) @@ -514,9 +506,9 @@ def test_add_auth_token_getter_unused_token( Tests ValueError when add_auth_token_getters is called with a getter for an unused authentication service. """ + transport = ToolboxTransport(HTTPS_BASE_URL, http_session) tool_instance = ToolboxTool( - session=http_session, - base_url=HTTPS_BASE_URL, + transport=transport, name=TEST_TOOL_NAME, description=sample_tool_description, params=sample_tool_params, @@ -534,219 +526,3 @@ def test_add_auth_token_getter_unused_token( next(iter(unused_auth_getters)), unused_auth_getters[next(iter(unused_auth_getters))], ) - - -def test_toolbox_tool_underscore_name_property(toolbox_tool: ToolboxTool): - """Tests the _name property.""" - assert toolbox_tool._name == TEST_TOOL_NAME - - -def test_toolbox_tool_underscore_description_property(toolbox_tool: ToolboxTool): - """Tests the _description property.""" - assert ( - toolbox_tool._description - == "A sample tool that processes a message and a count." - ) - - -def test_toolbox_tool_underscore_params_property( - toolbox_tool: ToolboxTool, sample_tool_params: list[ParameterSchema] -): - """Tests the _params property returns a deep copy.""" - params_copy = toolbox_tool._params - assert params_copy == sample_tool_params - assert ( - params_copy is not toolbox_tool._ToolboxTool__params - ) # Ensure it's a deepcopy - # Verify modifying the copy does not affect the original - params_copy.append( - ParameterSchema(name="new_param", type="integer", description="A new parameter") - ) - assert ( - len(toolbox_tool._ToolboxTool__params) == 2 - ) # Original should remain unchanged - - -def test_toolbox_tool_underscore_bound_params_property(toolbox_tool: ToolboxTool): - """Tests the _bound_params property returns an immutable MappingProxyType.""" - bound_params = toolbox_tool._bound_params - assert bound_params == {"fixed_param": "fixed_value"} - assert isinstance(bound_params, MappingProxyType) - # Verify immutability - with pytest.raises(TypeError): - bound_params["new_param"] = "new_value" - - -def test_toolbox_tool_underscore_required_authn_params_property( - toolbox_tool: ToolboxTool, -): - """Tests the _required_authn_params property returns an immutable MappingProxyType.""" - required_authn_params = toolbox_tool._required_authn_params - assert required_authn_params == {"message": ["service_a"]} - assert isinstance(required_authn_params, MappingProxyType) - # Verify immutability - with pytest.raises(TypeError): - required_authn_params["new_param"] = ["new_service"] - - -def test_toolbox_tool_underscore_required_authz_tokens_property( - toolbox_tool: ToolboxTool, -): - """Tests the _required_authz_tokens property returns an immutable MappingProxyType.""" - required_authz_tokens = toolbox_tool._required_authz_tokens - assert required_authz_tokens == ("service_b",) - assert isinstance(required_authz_tokens, tuple) - # Verify immutability - with pytest.raises(TypeError): - required_authz_tokens[0] = "new_service" - - -def test_toolbox_tool_underscore_auth_service_token_getters_property( - toolbox_tool: ToolboxTool, -): - """Tests the _auth_service_token_getters property returns an immutable MappingProxyType.""" - auth_getters = toolbox_tool._auth_service_token_getters - assert "service_x" in auth_getters - assert auth_getters["service_x"]() == "token_x" - assert isinstance(auth_getters, MappingProxyType) - # Verify immutability - with pytest.raises(TypeError): - auth_getters["new_service"] = lambda: "new_token" - - -def test_toolbox_tool_underscore_client_headers_property(toolbox_tool: ToolboxTool): - """Tests the _client_headers property returns an immutable MappingProxyType.""" - client_headers = toolbox_tool._client_headers - assert client_headers == {"X-Test-Client": "client_header_value"} - assert isinstance(client_headers, MappingProxyType) - # Verify immutability - with pytest.raises(TypeError): - client_headers["new_header"] = "new_value" - - -# --- Test for the HTTP Warning --- -@pytest.mark.parametrize( - "trigger_condition_params", - [ - {"client_headers": {"X-Some-Header": "value"}}, - {"required_authn_params": {"param1": ["auth-service1"]}}, - {"required_authz_tokens": ["auth-service2"]}, - { - "client_headers": {"X-Some-Header": "value"}, - "required_authn_params": {"param1": ["auth-service1"]}, - }, - { - "client_headers": {"X-Some-Header": "value"}, - "required_authz_tokens": ["auth-service2"], - }, - { - "required_authn_params": {"param1": ["auth-service1"]}, - "required_authz_tokens": ["auth-service2"], - }, - { - "client_headers": {"X-Some-Header": "value"}, - "required_authn_params": {"param1": ["auth-service1"]}, - "required_authz_tokens": ["auth-service2"], - }, - ], - ids=[ - "client_headers_only", - "authn_params_only", - "authz_tokens_only", - "headers_and_authn", - "headers_and_authz", - "authn_and_authz", - "all_three_conditions", - ], -) -def test_tool_init_http_warning_when_sensitive_info_over_http( - http_session: ClientSession, - sample_tool_params: list[ParameterSchema], - sample_tool_description: str, - trigger_condition_params: dict, -): - """ - Tests that a UserWarning is issued if client headers, auth params, or - auth tokens are present and the base_url is HTTP. - """ - expected_warning_message = ( - "Sending ID token over HTTP. User data may be exposed. " - "Use HTTPS for secure communication." - ) - - init_kwargs = { - "session": http_session, - "base_url": TEST_BASE_URL, - "name": "http_warning_tool", - "description": sample_tool_description, - "params": sample_tool_params, - "required_authn_params": {}, - "required_authz_tokens": [], - "auth_service_token_getters": {}, - "bound_params": {}, - "client_headers": {}, - } - # Apply the specific conditions for this parametrized test - init_kwargs.update(trigger_condition_params) - - with pytest.warns(UserWarning, match=expected_warning_message): - ToolboxTool(**init_kwargs) - - -def test_tool_init_no_http_warning_if_https( - http_session: ClientSession, - sample_tool_params: list[ParameterSchema], - sample_tool_description: str, - static_client_header: dict, -): - """ - Tests that NO UserWarning is issued if client headers are present but - the base_url is HTTPS. - """ - with catch_warnings(record=True) as record: - simplefilter("always") - - ToolboxTool( - session=http_session, - base_url=HTTPS_BASE_URL, - name="https_tool", - description=sample_tool_description, - params=sample_tool_params, - required_authn_params={}, - required_authz_tokens=[], - auth_service_token_getters={}, - bound_params={}, - client_headers=static_client_header, - ) - assert ( - len(record) == 0 - ), f"Expected no warnings, but got: {[f'{w.category.__name__}: {w.message}' for w in record]}" - - -def test_tool_init_no_http_warning_if_no_sensitive_info_on_http( - http_session: ClientSession, - sample_tool_params: list[ParameterSchema], - sample_tool_description: str, -): - """ - Tests that NO UserWarning is issued if the URL is HTTP but there are - no client headers, auth params, or auth tokens. - """ - with catch_warnings(record=True) as record: - simplefilter("always") - - ToolboxTool( - session=http_session, - base_url=TEST_BASE_URL, - name="http_tool_no_sensitive", - description=sample_tool_description, - params=sample_tool_params, - required_authn_params={}, - required_authz_tokens=[], - auth_service_token_getters={}, - bound_params={}, - client_headers={}, - ) - assert ( - len(record) == 0 - ), f"Expected no warnings, but got: {[f'{w.category.__name__}: {w.message}' for w in record]}" diff --git a/packages/toolbox-core/tests/test_toolbox_transport.py b/packages/toolbox-core/tests/test_toolbox_transport.py new file mode 100644 index 00000000..2921e09f --- /dev/null +++ b/packages/toolbox-core/tests/test_toolbox_transport.py @@ -0,0 +1,226 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import AsyncGenerator, Mapping, Optional, Union +from unittest.mock import AsyncMock + +import pytest +import pytest_asyncio +from aiohttp import ClientSession +from aioresponses import aioresponses + +from toolbox_core.protocol import ManifestSchema +from toolbox_core.toolbox_transport import ToolboxTransport + +TEST_BASE_URL = "http://fake-toolbox-server.com" +TEST_TOOL_NAME = "test_tool" + + +@pytest_asyncio.fixture +async def http_session() -> AsyncGenerator[ClientSession, None]: + """Provides a real aiohttp ClientSession that is closed after the test.""" + async with ClientSession() as session: + yield session + + +@pytest.fixture +def mock_manifest_dict() -> dict: + """Provides a valid sample dictionary for a ManifestSchema response.""" + tool_definition = { + "name": TEST_TOOL_NAME, + "description": "A test tool", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "The first parameter.", + "required": True, + } + ], + } + return { + "serverVersion": "1.0.0", + "tools": {TEST_TOOL_NAME: tool_definition}, + } + + +@pytest.mark.asyncio +async def test_base_url_property(http_session: ClientSession): + """Tests that the base_url property returns the correct URL.""" + transport = ToolboxTransport(TEST_BASE_URL, http_session) + assert transport.base_url == TEST_BASE_URL + + +@pytest.mark.asyncio +async def test_tool_get_success(http_session: ClientSession, mock_manifest_dict: dict): + """Tests a successful tool_get call.""" + url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}" + headers = {"X-Test-Header": "value"} + transport = ToolboxTransport(TEST_BASE_URL, http_session) + + with aioresponses() as m: + m.get(url, status=200, payload=mock_manifest_dict) + result = await transport.tool_get(TEST_TOOL_NAME, headers=headers) + + assert isinstance(result, ManifestSchema) + assert result.serverVersion == "1.0.0" + # FIX: Check for a valid attribute like 'description' instead of 'name' + assert result.tools[TEST_TOOL_NAME].description == "A test tool" + m.assert_called_once_with(url, headers=headers) + + +@pytest.mark.asyncio +async def test_tool_get_failure(http_session: ClientSession): + """Tests a failing tool_get call and ensures it raises RuntimeError.""" + url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}" + transport = ToolboxTransport(TEST_BASE_URL, http_session) + + with aioresponses() as m: + m.get(url, status=500, body="Internal Server Error") + with pytest.raises(RuntimeError) as exc_info: + await transport.tool_get(TEST_TOOL_NAME) + + assert "API request failed with status 500" in str(exc_info.value) + assert "Internal Server Error" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "toolset_name, expected_path", + [ + ("my_toolset", "/api/toolset/my_toolset"), + (None, "/api/toolset/"), + ], +) +async def test_tools_list_success( + http_session: ClientSession, + mock_manifest_dict: dict, + toolset_name: Union[str, None], + expected_path: str, +): + """Tests successful tools_list calls with and without a toolset name.""" + url = f"{TEST_BASE_URL}{expected_path}" + transport = ToolboxTransport(TEST_BASE_URL, http_session) + + with aioresponses() as m: + m.get(url, status=200, payload=mock_manifest_dict) + result = await transport.tools_list(toolset_name=toolset_name) + + assert isinstance(result, ManifestSchema) + # FIX: Add headers=None to match the actual call signature + m.assert_called_once_with(url, headers=None) + + +@pytest.mark.asyncio +async def test_tool_invoke_success(http_session: ClientSession): + """Tests a successful tool_invoke call.""" + url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}/invoke" + args = {"param1": "value1"} + headers = {"Authorization": "Bearer token"} + response_payload = {"result": "success"} + transport = ToolboxTransport(TEST_BASE_URL, http_session) + + with aioresponses() as m: + m.post(url, status=200, payload=response_payload) + result = await transport.tool_invoke(TEST_TOOL_NAME, args, headers) + + assert result == "success" + m.assert_called_once_with(url, method="POST", json=args, headers=headers) + + +@pytest.mark.asyncio +async def test_tool_invoke_failure(http_session: ClientSession): + """Tests a failing tool_invoke call where the server returns an error payload.""" + url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}/invoke" + response_payload = {"error": "Invalid arguments"} + transport = ToolboxTransport(TEST_BASE_URL, http_session) + + with aioresponses() as m: + m.post(url, status=400, payload=response_payload) + with pytest.raises(Exception) as exc_info: + await transport.tool_invoke(TEST_TOOL_NAME, {}, {}) + + assert str(exc_info.value) == "Invalid arguments" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "base_url, headers, should_warn", + [ + ( + "http://fake-toolbox-server.com", + {"Authorization": "Bearer token"}, + True, + ), + ( + "https://fake-toolbox-server.com", + {"Authorization": "Bearer token"}, + False, + ), + ("http://fake-toolbox-server.com", {}, False), + ("http://fake-toolbox-server.com", None, False), + ], +) +async def test_tool_invoke_http_warning( + http_session: ClientSession, + base_url: str, + headers: Optional[Mapping[str, str]], + should_warn: bool, +): + """Tests the HTTP security warning logic in tool_invoke.""" + url = f"{base_url}/api/tool/{TEST_TOOL_NAME}/invoke" + args = {"param1": "value1"} + response_payload = {"result": "success"} + transport = ToolboxTransport(base_url, http_session) + + with aioresponses() as m: + m.post(url, status=200, payload=response_payload) + + if should_warn: + with pytest.warns(UserWarning, match="Sending data token over HTTP"): + await transport.tool_invoke(TEST_TOOL_NAME, args, headers) + else: + # By not using pytest.warns, we assert that no warnings are raised. + # The test will fail if an unexpected UserWarning occurs. + await transport.tool_invoke(TEST_TOOL_NAME, args, headers) + + +@pytest.mark.asyncio +async def test_close_does_not_close_unmanaged_session(): + """ + Tests that close() does NOT affect a session that was provided externally + (i.e., an unmanaged session). + """ + mock_session = AsyncMock(spec=ClientSession) + mock_session.closed = False + + transport = ToolboxTransport(TEST_BASE_URL, mock_session) + await transport.close() + mock_session.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_closes_managed_session(): + """ + Tests that close() successfully closes a session that was created and + managed internally by the transport. + """ + transport = ToolboxTransport(TEST_BASE_URL, session=None) + # Access the internal session before closing to check its state + internal_session = transport._ToolboxTransport__session + assert internal_session.closed is False + + await transport.close() + internal_session = transport._ToolboxTransport__session + assert internal_session.closed is True diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index 96bd7660..84aeab52 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import types # For MappingProxyType -from unittest.mock import AsyncMock, Mock, patch +import types +from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio from pydantic import ValidationError from toolbox_core.protocol import ParameterSchema as CoreParameterSchema from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.toolbox_transport import ToolboxTransport from toolbox_langchain.async_tools import AsyncToolboxTool @@ -67,9 +68,9 @@ def _create_core_tool_from_dict( else: tool_constructor_params.append(p_schema) + transport = ToolboxTransport(base_url=url, session=session) return ToolboxCoreTool( - session=session, - base_url=url, + transport=transport, name=name, description=schema_dict["description"], params=tool_constructor_params, @@ -86,10 +87,10 @@ def _create_core_tool_from_dict( @patch("aiohttp.ClientSession") async def toolbox_tool(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_response = mock_session.post.return_value.__aenter__.return_value - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"result": "test-result"}) - mock_response.status = 200 # *** Fix: Set status for the mock response *** + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + mock_session.post.return_value.__aenter__.return_value.ok = True core_tool_instance = self._create_core_tool_from_dict( session=mock_session, @@ -104,10 +105,10 @@ async def toolbox_tool(self, MockClientSession, tool_schema_dict): @patch("aiohttp.ClientSession") async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): mock_session = MockClientSession.return_value - mock_response = mock_session.post.return_value.__aenter__.return_value - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"result": "test-result"}) - mock_response.status = 200 # *** Fix: Set status for the mock response *** + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + mock_session.post.return_value.__aenter__.return_value.ok = True core_tool_instance = self._create_core_tool_from_dict( session=mock_session, @@ -121,8 +122,6 @@ async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): @patch("aiohttp.ClientSession") async def test_toolbox_tool_init(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_response = mock_session.post.return_value.__aenter__.return_value - mock_response.status = 200 core_tool_instance = self._create_core_tool_from_dict( session=mock_session, name="test_tool", @@ -173,8 +172,6 @@ async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): auth_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool - # Verify that 'param1' is not in the list of bindable parameters for the core tool - # because it requires authentication. assert "param1" not in [p.name for p in auth_core_tool._ToolboxTool__params] with pytest.raises( ValueError, match="unable to bind parameters: no parameter named param1" @@ -227,19 +224,6 @@ async def test_toolbox_tool_add_unused_auth_token_getter_raises_error( in str(excinfo.value) ) - valid_lambda = lambda: "test-token" - with pytest.raises(ValueError) as excinfo_mixed: - auth_toolbox_tool.add_auth_token_getters( - { - "test-auth-source": valid_lambda, - "another-auth-source": unused_lambda, - } - ) - assert ( - "Authentication source(s) `another-auth-source` unused by tool `test_tool`" - in str(excinfo_mixed.value) - ) - async def test_toolbox_tool_add_auth_token_getters_duplicate( self, auth_toolbox_tool ): @@ -263,7 +247,8 @@ async def test_toolbox_tool_call(self, toolbox_tool): result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) assert result == "test-result" core_tool = toolbox_tool._AsyncToolboxTool__core_tool - core_tool._ToolboxTool__session.post.assert_called_once_with( + session = core_tool._ToolboxTool__transport._ToolboxTransport__session + session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": "test-value", "param2": 123}, headers={}, @@ -283,7 +268,8 @@ async def test_toolbox_tool_call_with_bound_params( result = await tool.ainvoke({"param2": 123}) assert result == "test-result" core_tool = tool._AsyncToolboxTool__core_tool - core_tool._ToolboxTool__session.post.assert_called_once_with( + session = core_tool._ToolboxTool__transport._ToolboxTransport__session + session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": expected_value, "param2": 123}, headers={}, @@ -296,52 +282,13 @@ async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): result = await tool.ainvoke({"param2": 123}) assert result == "test-result" core_tool = tool._AsyncToolboxTool__core_tool - core_tool._ToolboxTool__session.post.assert_called_once_with( + session = core_tool._ToolboxTool__transport._ToolboxTransport__session + session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, ) - async def test_toolbox_tool_call_with_auth_tokens_insecure( - self, auth_toolbox_tool, auth_tool_schema_dict - ): - core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool - mock_session = core_tool_of_auth_tool._ToolboxTool__session - - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - insecure_core_tool = self._create_core_tool_from_dict( - session=mock_session, - name="test_tool", - schema_dict=auth_tool_schema_dict, - url="http://test-url", - ) - - insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) - - tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool_with_getter.ainvoke({"param2": 123}) - assert result == "test-result" - - modified_core_tool_in_new_tool = tool_with_getter._AsyncToolboxTool__core_tool - assert ( - modified_core_tool_in_new_tool._ToolboxTool__base_url == "http://test-url" - ) - assert ( - modified_core_tool_in_new_tool._ToolboxTool__url - == "http://test-url/api/tool/test_tool/invoke" - ) - - modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, - ) - async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): with pytest.raises(ValidationError) as e: await toolbox_tool.ainvoke({"param1": 123, "param2": "invalid"}) diff --git a/packages/toolbox-llamaindex/tests/test_async_tools.py b/packages/toolbox-llamaindex/tests/test_async_tools.py index 251c88cd..2230d51e 100644 --- a/packages/toolbox-llamaindex/tests/test_async_tools.py +++ b/packages/toolbox-llamaindex/tests/test_async_tools.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import types # For MappingProxyType -from unittest.mock import AsyncMock, Mock, patch +import types +from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio @@ -21,6 +21,7 @@ from pydantic import ValidationError from toolbox_core.protocol import ParameterSchema as CoreParameterSchema from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.toolbox_transport import ToolboxTransport from toolbox_llamaindex.async_tools import AsyncToolboxTool @@ -68,9 +69,9 @@ def _create_core_tool_from_dict( else: tool_constructor_params.append(p_schema) + transport = ToolboxTransport(base_url=url, session=session) return ToolboxCoreTool( - session=session, - base_url=url, + transport=transport, name=name, description=schema_dict["description"], params=tool_constructor_params, @@ -87,10 +88,10 @@ def _create_core_tool_from_dict( @patch("aiohttp.ClientSession") async def toolbox_tool(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_response = mock_session.post.return_value.__aenter__.return_value - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"result": "test-result"}) - mock_response.status = 200 # *** Fix: Set status for the mock response *** + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + mock_session.post.return_value.__aenter__.return_value.ok = True core_tool_instance = self._create_core_tool_from_dict( session=mock_session, @@ -105,11 +106,10 @@ async def toolbox_tool(self, MockClientSession, tool_schema_dict): @patch("aiohttp.ClientSession") async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): mock_session = MockClientSession.return_value - mock_response = mock_session.post.return_value.__aenter__.return_value - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"result": "test-result"}) - mock_response.status = 200 # *** Fix: Set status for the mock response *** - + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + mock_session.post.return_value.__aenter__.return_value.ok = True core_tool_instance = self._create_core_tool_from_dict( session=mock_session, name="test_tool", @@ -122,8 +122,6 @@ async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): @patch("aiohttp.ClientSession") async def test_toolbox_tool_init(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_response = mock_session.post.return_value.__aenter__.return_value - mock_response.status = 200 core_tool_instance = self._create_core_tool_from_dict( session=mock_session, name="test_tool", @@ -174,8 +172,6 @@ async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): auth_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool - # Verify that 'param1' is not in the list of bindable parameters for the core tool - # because it requires authentication. assert "param1" not in [p.name for p in auth_core_tool._ToolboxTool__params] with pytest.raises( ValueError, match="unable to bind parameters: no parameter named param1" @@ -228,19 +224,6 @@ async def test_toolbox_tool_add_unused_auth_token_getter_raises_error( in str(excinfo.value) ) - valid_lambda = lambda: "test-token" - with pytest.raises(ValueError) as excinfo_mixed: - auth_toolbox_tool.add_auth_token_getters( - { - "test-auth-source": valid_lambda, - "another-auth-source": unused_lambda, - } - ) - assert ( - "Authentication source(s) `another-auth-source` unused by tool `test_tool`" - in str(excinfo_mixed.value) - ) - async def test_toolbox_tool_add_auth_token_getters_duplicate( self, auth_toolbox_tool ): @@ -267,11 +250,11 @@ async def test_toolbox_tool_call(self, toolbox_tool): tool_name="test_tool", raw_input={"param1": "test-value", "param2": 123}, raw_output="test-result", - is_error=False, ) core_tool = toolbox_tool._AsyncToolboxTool__core_tool - core_tool._ToolboxTool__session.post.assert_called_once_with( + session = core_tool._ToolboxTool__transport._ToolboxTransport__session + session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": "test-value", "param2": 123}, headers={}, @@ -294,10 +277,10 @@ async def test_toolbox_tool_call_with_bound_params( tool_name="test_tool", raw_input={"param2": 123}, raw_output="test-result", - is_error=False, ) core_tool = tool._AsyncToolboxTool__core_tool - core_tool._ToolboxTool__session.post.assert_called_once_with( + session = core_tool._ToolboxTool__transport._ToolboxTransport__session + session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": expected_value, "param2": 123}, headers={}, @@ -313,62 +296,16 @@ async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): tool_name="test_tool", raw_input={"param2": 123}, raw_output="test-result", - is_error=False, ) core_tool = tool._AsyncToolboxTool__core_tool - core_tool._ToolboxTool__session.post.assert_called_once_with( + session = core_tool._ToolboxTool__transport._ToolboxTransport__session + session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, ) - async def test_toolbox_tool_call_with_auth_tokens_insecure( - self, auth_toolbox_tool, auth_tool_schema_dict - ): - core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool - mock_session = core_tool_of_auth_tool._ToolboxTool__session - - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - insecure_core_tool = self._create_core_tool_from_dict( - session=mock_session, - name="test_tool", - schema_dict=auth_tool_schema_dict, - url="http://test-url", - ) - - insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) - - tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool_with_getter.acall(param2=123) - assert result == ToolOutput( - content="test-result", - tool_name="test_tool", - raw_input={"param2": 123}, - raw_output="test-result", - is_error=False, - ) - - modified_core_tool_in_new_tool = tool_with_getter._AsyncToolboxTool__core_tool - assert ( - modified_core_tool_in_new_tool._ToolboxTool__base_url == "http://test-url" - ) - assert ( - modified_core_tool_in_new_tool._ToolboxTool__url - == "http://test-url/api/tool/test_tool/invoke" - ) - - modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, - ) - async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): with pytest.raises(TypeError) as e: await toolbox_tool.acall()