From e1712f8227d9aa095e87670966909c4a45e7f6ed Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 13:13:17 +0100 Subject: [PATCH] Refactor openapi support from Tool --- .../components/tools/openapi.py | 152 ++++++++++++++++++ haystack_experimental/dataclasses/tool.py | 130 +-------------- .../tools}/test_tool_openapi.py | 19 +-- 3 files changed, 161 insertions(+), 140 deletions(-) create mode 100644 haystack_experimental/components/tools/openapi.py rename test/{dataclasses => components/tools}/test_tool_openapi.py (84%) diff --git a/haystack_experimental/components/tools/openapi.py b/haystack_experimental/components/tools/openapi.py new file mode 100644 index 00000000..90371d55 --- /dev/null +++ b/haystack_experimental/components/tools/openapi.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, TypedDict, Union + +from haystack.lazy_imports import LazyImport +from haystack.logging import logging + +from haystack_experimental.dataclasses import Tool + +with LazyImport(message="Run 'pip install openapi-llm'") as openapi_llm_import: + from openapi_llm.client.config import ClientConfig + from openapi_llm.client.openapi import OpenAPIClient + from openapi_llm.core.spec import OpenAPISpecification + + +logger = logging.getLogger(__name__) + + +class OpenAPIKwargs(TypedDict, total=False): + """ + TypedDict for OpenAPI configuration kwargs. + + Contains all supported configuration options for Tool.from_openapi_spec() + """ + + credentials: Any # API credentials (e.g., API key, auth token) + request_sender: Callable[[Dict[str, Any]], Dict[str, Any]] # Custom HTTPrequest sender function + allowed_operations: List[str] # A list of operations to include in the OpenAPI client. + + +def create_tools_from_openapi_spec(spec: Union[str, Path], **kwargs: OpenAPIKwargs) -> List["Tool"]: + """ + Create Tool instances from an OpenAPI specification. + + The specification can be provided as: + - A URL pointing to an OpenAPI spec + - A local file path to an OpenAPI spec (JSON or YAML) + - A string containing the OpenAPI spec content (JSON or YAML) + + :param spec: OpenAPI specification as URL, file path, or string content + :param kwargs: Additional configuration options for the OpenAPI client: + - credentials: API credentials (e.g., API key, auth token) + - request_sender: Custom callable to send HTTP requests + - allowed_operations: List of operations from the OpenAPI spec to include + :returns: List of Tool instances configured to invoke the OpenAPI service endpoints + :raises ValueError: If the OpenAPI specification is invalid or cannot be loaded + """ + openapi_llm_import.check() + + # Load the OpenAPI specification + if isinstance(spec, str): + if spec.startswith(("http://", "https://")): + openapi_spec = OpenAPISpecification.from_url(spec) + elif Path(spec).exists(): + openapi_spec = OpenAPISpecification.from_file(spec) + else: + openapi_spec = OpenAPISpecification.from_str(spec) + elif isinstance(spec, Path): + openapi_spec = OpenAPISpecification.from_file(str(spec)) + else: + raise ValueError("spec must be a string (URL, file path, or content) or a Path object") + + # Create client configuration + config = ClientConfig(openapi_spec=openapi_spec, **kwargs) + + # Create an OpenAPI client for invocations + client = OpenAPIClient(config) + + # Get all tool definitions from the config + tools = [] + for llm_specific_tool_def in config.get_tool_definitions(): + # Extract normalized tool definition + standardized_tool_def = _standardize_tool_definition(llm_specific_tool_def) + if not standardized_tool_def: + logger.warning(f"Skipping {llm_specific_tool_def}, as required parameters not found") + continue + + # Create a closure that captures the current value of standardized_tool_def + def create_invoke_function(tool_def: Dict[str, Any]) -> Callable: + """ + Create an invoke function with the tool definition bound to its scope. + + :param tool_def: The tool definition to bind to the invoke function. + :returns: Function that invokes the OpenAPI endpoint. + """ + + def invoke_openapi(**kwargs): + """ + Invoke the OpenAPI endpoint with the provided arguments. + + :param kwargs: Arguments to pass to the OpenAPI endpoint. + :returns: Response from the OpenAPI endpoint. + """ + return client.invoke({"name": tool_def["name"], "arguments": kwargs}) + + return invoke_openapi + + tools.append( + Tool( + name=standardized_tool_def["name"], + description=standardized_tool_def["description"], + parameters=standardized_tool_def["parameters"], + function=create_invoke_function(standardized_tool_def), + ) + ) + + return tools + + +def _standardize_tool_definition(llm_specific_tool_def: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Recursively extract tool parameters from different LLM provider formats. + + Supports various LLM provider formats including OpenAI, Anthropic, and Cohere. + + :param llm_specific_tool_def: Dictionary containing tool definition in provider-specific format + :returns: Dictionary with normalized tool parameters or None if required fields not found + """ + # Mapping of provider-specific schema field names to our Tool "parameters" field + SCHEMA_FIELD_NAMES = [ + "parameters", # Cohere/OpenAI + "input_schema", # Anthropic + # any other field names that might contain a schema in other providers + ] + + def _find_in_dict(d: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if all(k in d for k in ["name", "description"]): + schema = None + for field_name in SCHEMA_FIELD_NAMES: + if field_name in d: + schema = d[field_name] + break + + if schema is not None: + return { + "name": d["name"], + "description": d["description"], + "parameters": schema, + } + + # Recurse into nested dictionaries + for v in d.values(): + if isinstance(v, dict): + result = _find_in_dict(v) + if result: + return result + return None + + return _find_in_dict(llm_specific_tool_def) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index 098e1cd6..4f28f184 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -4,27 +4,16 @@ import inspect from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict from haystack.lazy_imports import LazyImport -from haystack.logging import logging from haystack.utils import deserialize_callable, serialize_callable from pydantic import create_model -from haystack_experimental.dataclasses.types import OpenAPIKwargs - -logger = logging.getLogger(__name__) - with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: from jsonschema import Draft202012Validator from jsonschema.exceptions import SchemaError -with LazyImport(message="Run 'pip install openapi-llm'") as openapi_llm_import: - from openapi_llm.client.config import ClientConfig - from openapi_llm.client.openapi import OpenAPIClient - from openapi_llm.core.spec import OpenAPISpecification - class ToolInvocationError(Exception): """ @@ -207,85 +196,6 @@ def get_weather( return Tool(name=function.__name__, description=tool_description, parameters=schema, function=function) - @classmethod - def from_openapi_spec(cls, spec: Union[str, Path], **kwargs: OpenAPIKwargs) -> List["Tool"]: - """ - Create Tool instances from an OpenAPI specification. - - The specification can be provided as: - - A URL pointing to an OpenAPI spec - - A local file path to an OpenAPI spec (JSON or YAML) - - A string containing the OpenAPI spec content (JSON or YAML) - - :param spec: OpenAPI specification as URL, file path, or string content - :param kwargs: Additional configuration options for the OpenAPI client: - - credentials: API credentials (e.g., API key, auth token) - - request_sender: Custom callable to send HTTP requests - - allowed_operations: List of operations from the OpenAPI spec to include - :returns: List of Tool instances configured to invoke the OpenAPI service endpoints - :raises ValueError: If the OpenAPI specification is invalid or cannot be loaded - """ - openapi_llm_import.check() - - # Load the OpenAPI specification - if isinstance(spec, str): - if spec.startswith(("http://", "https://")): - openapi_spec = OpenAPISpecification.from_url(spec) - elif Path(spec).exists(): - openapi_spec = OpenAPISpecification.from_file(spec) - else: - openapi_spec = OpenAPISpecification.from_str(spec) - elif isinstance(spec, Path): - openapi_spec = OpenAPISpecification.from_file(str(spec)) - else: - raise ValueError("spec must be a string (URL, file path, or content) or a Path object") - - # Create client configuration - config = ClientConfig(openapi_spec=openapi_spec, **kwargs) - - # Create an OpenAPI client for invocations - client = OpenAPIClient(config) - - # Get all tool definitions from the config - tools = [] - for llm_specific_tool_def in config.get_tool_definitions(): - # Extract normalized tool definition - standardized_tool_def = _standardize_tool_definition(llm_specific_tool_def) - if not standardized_tool_def: - logger.warning(f"Skipping {llm_specific_tool_def}, as required parameters not found") - continue - - # Create a closure that captures the current value of standardized_tool_def - def create_invoke_function(tool_def: Dict[str, Any]) -> Callable: - """ - Create an invoke function with the tool definition bound to its scope. - - :param tool_def: The tool definition to bind to the invoke function. - :returns: Function that invokes the OpenAPI endpoint. - """ - - def invoke_openapi(**kwargs): - """ - Invoke the OpenAPI endpoint with the provided arguments. - - :param kwargs: Arguments to pass to the OpenAPI endpoint. - :returns: Response from the OpenAPI endpoint. - """ - return client.invoke({"name": tool_def["name"], "arguments": kwargs}) - - return invoke_openapi - - tools.append( - cls( - name=standardized_tool_def["name"], - description=standardized_tool_def["description"], - parameters=standardized_tool_def["parameters"], - function=create_invoke_function(standardized_tool_def), - ) - ) - - return tools - def _remove_title_from_schema(schema: Dict[str, Any]): """ @@ -327,41 +237,3 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): deserialized_tools.append(Tool.from_dict(tool)) data[key] = deserialized_tools - - -def _standardize_tool_definition(llm_specific_tool_def: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """ - Recursively extract tool parameters from different LLM provider formats. - - Supports various LLM provider formats including OpenAI, Anthropic, and Cohere. - - :param llm_specific_tool_def: Dictionary containing tool definition in provider-specific format - :returns: Dictionary with normalized tool parameters or None if required fields not found - """ - # Mapping of provider-specific schema field names to our Tool "parameters" field - SCHEMA_FIELD_NAMES = [ - "parameters", # Cohere/OpenAI - "input_schema", # Anthropic - # any other field names that might contain a schema in other providers - ] - - def _find_in_dict(d: Dict[str, Any]) -> Optional[Dict[str, Any]]: - if all(k in d for k in ["name", "description"]): - schema = None - for field_name in SCHEMA_FIELD_NAMES: - if field_name in d: - schema = d[field_name] - break - - if schema is not None: - return {"name": d["name"], "description": d["description"], "parameters": schema} - - # Recurse into nested dictionaries - for v in d.values(): - if isinstance(v, dict): - result = _find_in_dict(v) - if result: - return result - return None - - return _find_in_dict(llm_specific_tool_def) diff --git a/test/dataclasses/test_tool_openapi.py b/test/components/tools/test_tool_openapi.py similarity index 84% rename from test/dataclasses/test_tool_openapi.py rename to test/components/tools/test_tool_openapi.py index 8a0b4c91..4d2a4abb 100644 --- a/test/dataclasses/test_tool_openapi.py +++ b/test/components/tools/test_tool_openapi.py @@ -5,12 +5,9 @@ """Tests for Tool class OpenAPI functionality.""" import os -from typing import List - import pytest -from haystack_experimental.dataclasses.tool import Tool -from haystack_experimental.dataclasses.tool import OpenAPIKwargs +from haystack_experimental.components.tools.openapi import OpenAPIKwargs, create_tools_from_openapi_spec class TestToolOpenAPI: @@ -36,7 +33,7 @@ def test_from_openapi_spec_serperdev(self): assert serper_api_key is not None # Direct kwargs usage - tools = Tool.from_openapi_spec( + tools = create_tools_from_openapi_spec( spec="https://bit.ly/serperdev_openapi", credentials=serper_api_key ) @@ -49,7 +46,7 @@ def test_from_openapi_spec_serperdev(self): config = OpenAPIKwargs( credentials=serper_api_key ) - tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) + tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) assert len(tools) >= 1 tool = tools[0] assert tool.name == "search" @@ -59,7 +56,7 @@ def test_from_openapi_spec_serperdev(self): config = OpenAPIKwargs(**{ "credentials": serper_api_key }) - tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) + tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) assert len(tools) >= 1 tool = tools[0] assert tool.name == "search" @@ -85,7 +82,7 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self): assert serper_api_key is not None # Direct kwargs usage - tools = Tool.from_openapi_spec( + tools = create_tools_from_openapi_spec( spec="https://bit.ly/serperdev_openapi", credentials=serper_api_key, allowed_operations=["search"] @@ -100,7 +97,7 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self): credentials=serper_api_key, allowed_operations=["search"] ) - tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) + tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) assert len(tools) >= 1 tool = tools[0] assert tool.name == "search" @@ -111,7 +108,7 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self): "credentials": serper_api_key, "allowed_operations": ["search"] }) - tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) + tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) assert len(tools) >= 1 tool = tools[0] assert tool.name == "search" @@ -122,5 +119,5 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self): credentials=serper_api_key, allowed_operations=["super-search"] ) - tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) + tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config) assert len(tools) == 0