Skip to content

Commit

Permalink
Refactor openapi support from Tool
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Dec 10, 2024
1 parent 5464281 commit e1712f8
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 140 deletions.
152 changes: 152 additions & 0 deletions haystack_experimental/components/tools/openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union

from haystack.lazy_imports import LazyImport
from haystack.logging import logging

from haystack_experimental.dataclasses import Tool

with LazyImport(message="Run 'pip install openapi-llm'") as openapi_llm_import:
from openapi_llm.client.config import ClientConfig
from openapi_llm.client.openapi import OpenAPIClient
from openapi_llm.core.spec import OpenAPISpecification


logger = logging.getLogger(__name__)


class OpenAPIKwargs(TypedDict, total=False):
"""
TypedDict for OpenAPI configuration kwargs.
Contains all supported configuration options for Tool.from_openapi_spec()
"""

credentials: Any # API credentials (e.g., API key, auth token)
request_sender: Callable[[Dict[str, Any]], Dict[str, Any]] # Custom HTTPrequest sender function
allowed_operations: List[str] # A list of operations to include in the OpenAPI client.


def create_tools_from_openapi_spec(spec: Union[str, Path], **kwargs: OpenAPIKwargs) -> List["Tool"]:
"""
Create Tool instances from an OpenAPI specification.
The specification can be provided as:
- A URL pointing to an OpenAPI spec
- A local file path to an OpenAPI spec (JSON or YAML)
- A string containing the OpenAPI spec content (JSON or YAML)
:param spec: OpenAPI specification as URL, file path, or string content
:param kwargs: Additional configuration options for the OpenAPI client:
- credentials: API credentials (e.g., API key, auth token)
- request_sender: Custom callable to send HTTP requests
- allowed_operations: List of operations from the OpenAPI spec to include
:returns: List of Tool instances configured to invoke the OpenAPI service endpoints
:raises ValueError: If the OpenAPI specification is invalid or cannot be loaded
"""
openapi_llm_import.check()

# Load the OpenAPI specification
if isinstance(spec, str):
if spec.startswith(("http://", "https://")):
openapi_spec = OpenAPISpecification.from_url(spec)
elif Path(spec).exists():
openapi_spec = OpenAPISpecification.from_file(spec)
else:
openapi_spec = OpenAPISpecification.from_str(spec)
elif isinstance(spec, Path):
openapi_spec = OpenAPISpecification.from_file(str(spec))
else:
raise ValueError("spec must be a string (URL, file path, or content) or a Path object")

# Create client configuration
config = ClientConfig(openapi_spec=openapi_spec, **kwargs)

# Create an OpenAPI client for invocations
client = OpenAPIClient(config)

# Get all tool definitions from the config
tools = []
for llm_specific_tool_def in config.get_tool_definitions():
# Extract normalized tool definition
standardized_tool_def = _standardize_tool_definition(llm_specific_tool_def)
if not standardized_tool_def:
logger.warning(f"Skipping {llm_specific_tool_def}, as required parameters not found")
continue

# Create a closure that captures the current value of standardized_tool_def
def create_invoke_function(tool_def: Dict[str, Any]) -> Callable:
"""
Create an invoke function with the tool definition bound to its scope.
:param tool_def: The tool definition to bind to the invoke function.
:returns: Function that invokes the OpenAPI endpoint.
"""

def invoke_openapi(**kwargs):
"""
Invoke the OpenAPI endpoint with the provided arguments.
:param kwargs: Arguments to pass to the OpenAPI endpoint.
:returns: Response from the OpenAPI endpoint.
"""
return client.invoke({"name": tool_def["name"], "arguments": kwargs})

return invoke_openapi

tools.append(
Tool(
name=standardized_tool_def["name"],
description=standardized_tool_def["description"],
parameters=standardized_tool_def["parameters"],
function=create_invoke_function(standardized_tool_def),
)
)

return tools


def _standardize_tool_definition(llm_specific_tool_def: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Recursively extract tool parameters from different LLM provider formats.
Supports various LLM provider formats including OpenAI, Anthropic, and Cohere.
:param llm_specific_tool_def: Dictionary containing tool definition in provider-specific format
:returns: Dictionary with normalized tool parameters or None if required fields not found
"""
# Mapping of provider-specific schema field names to our Tool "parameters" field
SCHEMA_FIELD_NAMES = [
"parameters", # Cohere/OpenAI
"input_schema", # Anthropic
# any other field names that might contain a schema in other providers
]

def _find_in_dict(d: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if all(k in d for k in ["name", "description"]):
schema = None
for field_name in SCHEMA_FIELD_NAMES:
if field_name in d:
schema = d[field_name]
break

if schema is not None:
return {
"name": d["name"],
"description": d["description"],
"parameters": schema,
}

# Recurse into nested dictionaries
for v in d.values():
if isinstance(v, dict):
result = _find_in_dict(v)
if result:
return result
return None

return _find_in_dict(llm_specific_tool_def)
130 changes: 1 addition & 129 deletions haystack_experimental/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,16 @@

import inspect
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict

from haystack.lazy_imports import LazyImport
from haystack.logging import logging
from haystack.utils import deserialize_callable, serialize_callable
from pydantic import create_model

from haystack_experimental.dataclasses.types import OpenAPIKwargs

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError

with LazyImport(message="Run 'pip install openapi-llm'") as openapi_llm_import:
from openapi_llm.client.config import ClientConfig
from openapi_llm.client.openapi import OpenAPIClient
from openapi_llm.core.spec import OpenAPISpecification


class ToolInvocationError(Exception):
"""
Expand Down Expand Up @@ -207,85 +196,6 @@ def get_weather(

return Tool(name=function.__name__, description=tool_description, parameters=schema, function=function)

@classmethod
def from_openapi_spec(cls, spec: Union[str, Path], **kwargs: OpenAPIKwargs) -> List["Tool"]:
"""
Create Tool instances from an OpenAPI specification.
The specification can be provided as:
- A URL pointing to an OpenAPI spec
- A local file path to an OpenAPI spec (JSON or YAML)
- A string containing the OpenAPI spec content (JSON or YAML)
:param spec: OpenAPI specification as URL, file path, or string content
:param kwargs: Additional configuration options for the OpenAPI client:
- credentials: API credentials (e.g., API key, auth token)
- request_sender: Custom callable to send HTTP requests
- allowed_operations: List of operations from the OpenAPI spec to include
:returns: List of Tool instances configured to invoke the OpenAPI service endpoints
:raises ValueError: If the OpenAPI specification is invalid or cannot be loaded
"""
openapi_llm_import.check()

# Load the OpenAPI specification
if isinstance(spec, str):
if spec.startswith(("http://", "https://")):
openapi_spec = OpenAPISpecification.from_url(spec)
elif Path(spec).exists():
openapi_spec = OpenAPISpecification.from_file(spec)
else:
openapi_spec = OpenAPISpecification.from_str(spec)
elif isinstance(spec, Path):
openapi_spec = OpenAPISpecification.from_file(str(spec))
else:
raise ValueError("spec must be a string (URL, file path, or content) or a Path object")

# Create client configuration
config = ClientConfig(openapi_spec=openapi_spec, **kwargs)

# Create an OpenAPI client for invocations
client = OpenAPIClient(config)

# Get all tool definitions from the config
tools = []
for llm_specific_tool_def in config.get_tool_definitions():
# Extract normalized tool definition
standardized_tool_def = _standardize_tool_definition(llm_specific_tool_def)
if not standardized_tool_def:
logger.warning(f"Skipping {llm_specific_tool_def}, as required parameters not found")
continue

# Create a closure that captures the current value of standardized_tool_def
def create_invoke_function(tool_def: Dict[str, Any]) -> Callable:
"""
Create an invoke function with the tool definition bound to its scope.
:param tool_def: The tool definition to bind to the invoke function.
:returns: Function that invokes the OpenAPI endpoint.
"""

def invoke_openapi(**kwargs):
"""
Invoke the OpenAPI endpoint with the provided arguments.
:param kwargs: Arguments to pass to the OpenAPI endpoint.
:returns: Response from the OpenAPI endpoint.
"""
return client.invoke({"name": tool_def["name"], "arguments": kwargs})

return invoke_openapi

tools.append(
cls(
name=standardized_tool_def["name"],
description=standardized_tool_def["description"],
parameters=standardized_tool_def["parameters"],
function=create_invoke_function(standardized_tool_def),
)
)

return tools


def _remove_title_from_schema(schema: Dict[str, Any]):
"""
Expand Down Expand Up @@ -327,41 +237,3 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
deserialized_tools.append(Tool.from_dict(tool))

data[key] = deserialized_tools


def _standardize_tool_definition(llm_specific_tool_def: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Recursively extract tool parameters from different LLM provider formats.
Supports various LLM provider formats including OpenAI, Anthropic, and Cohere.
:param llm_specific_tool_def: Dictionary containing tool definition in provider-specific format
:returns: Dictionary with normalized tool parameters or None if required fields not found
"""
# Mapping of provider-specific schema field names to our Tool "parameters" field
SCHEMA_FIELD_NAMES = [
"parameters", # Cohere/OpenAI
"input_schema", # Anthropic
# any other field names that might contain a schema in other providers
]

def _find_in_dict(d: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if all(k in d for k in ["name", "description"]):
schema = None
for field_name in SCHEMA_FIELD_NAMES:
if field_name in d:
schema = d[field_name]
break

if schema is not None:
return {"name": d["name"], "description": d["description"], "parameters": schema}

# Recurse into nested dictionaries
for v in d.values():
if isinstance(v, dict):
result = _find_in_dict(v)
if result:
return result
return None

return _find_in_dict(llm_specific_tool_def)
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
"""Tests for Tool class OpenAPI functionality."""

import os
from typing import List

import pytest

from haystack_experimental.dataclasses.tool import Tool
from haystack_experimental.dataclasses.tool import OpenAPIKwargs
from haystack_experimental.components.tools.openapi import OpenAPIKwargs, create_tools_from_openapi_spec


class TestToolOpenAPI:
Expand All @@ -36,7 +33,7 @@ def test_from_openapi_spec_serperdev(self):
assert serper_api_key is not None

# Direct kwargs usage
tools = Tool.from_openapi_spec(
tools = create_tools_from_openapi_spec(
spec="https://bit.ly/serperdev_openapi",
credentials=serper_api_key
)
Expand All @@ -49,7 +46,7 @@ def test_from_openapi_spec_serperdev(self):
config = OpenAPIKwargs(
credentials=serper_api_key
)
tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
assert len(tools) >= 1
tool = tools[0]
assert tool.name == "search"
Expand All @@ -59,7 +56,7 @@ def test_from_openapi_spec_serperdev(self):
config = OpenAPIKwargs(**{
"credentials": serper_api_key
})
tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
assert len(tools) >= 1
tool = tools[0]
assert tool.name == "search"
Expand All @@ -85,7 +82,7 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self):
assert serper_api_key is not None

# Direct kwargs usage
tools = Tool.from_openapi_spec(
tools = create_tools_from_openapi_spec(
spec="https://bit.ly/serperdev_openapi",
credentials=serper_api_key,
allowed_operations=["search"]
Expand All @@ -100,7 +97,7 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self):
credentials=serper_api_key,
allowed_operations=["search"]
)
tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
assert len(tools) >= 1
tool = tools[0]
assert tool.name == "search"
Expand All @@ -111,7 +108,7 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self):
"credentials": serper_api_key,
"allowed_operations": ["search"]
})
tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
assert len(tools) >= 1
tool = tools[0]
assert tool.name == "search"
Expand All @@ -122,5 +119,5 @@ def test_from_openapi_spec_serperdev_with_allowed_operations(self):
credentials=serper_api_key,
allowed_operations=["super-search"]
)
tools = Tool.from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
tools = create_tools_from_openapi_spec(spec="https://bit.ly/serperdev_openapi", **config)
assert len(tools) == 0

0 comments on commit e1712f8

Please sign in to comment.