From 25551db3b323c294188d3efdd122ac4f99be709e Mon Sep 17 00:00:00 2001 From: va Date: Fri, 28 Feb 2025 10:51:42 -0500 Subject: [PATCH 1/5] feat: add runcontext + retryable + emitter to tool Signed-off-by: va --- .../agents/runners/default/runner.py | 2 +- python/beeai_framework/tools/mcp_tools.py | 3 +- .../tools/search/duckduckgo.py | 8 ++ .../beeai_framework/tools/search/wikipedia.py | 11 +++ python/beeai_framework/tools/tool.py | 76 ++++++++++++++++++- .../tools/weather/openmeteo.py | 11 +++ python/docs/tools.md | 32 ++++++-- python/examples/tools/advanced.py | 2 +- python/examples/tools/base.py | 2 +- python/examples/tools/custom/base.py | 13 +++- python/examples/tools/custom/openlibrary.py | 13 +++- python/examples/tools/wikipedia.py | 2 +- python/tests/tools/test_duckduckgo.py | 10 ++- python/tests/tools/test_opemmeteo.py | 46 +++++++---- python/tests/tools/test_wikipedia.py | 25 +++--- 15 files changed, 211 insertions(+), 45 deletions(-) diff --git a/python/beeai_framework/agents/runners/default/runner.py b/python/beeai_framework/agents/runners/default/runner.py index b73bb4f63..9d2209968 100644 --- a/python/beeai_framework/agents/runners/default/runner.py +++ b/python/beeai_framework/agents/runners/default/runner.py @@ -241,7 +241,7 @@ async def executor(_: RetryableContext) -> BeeRunnerToolResult: # tool_options = copy.copy(self._options) # TODO Tool run is not async # Convert tool input to dict - tool_output: ToolOutput = tool.run(input.state.tool_input, options={}) # TODO: pass tool options + tool_output: ToolOutput = await tool.run(input.state.tool_input, options={}) # TODO: pass tool options return BeeRunnerToolResult(output=tool_output, success=True) # TODO These error templates should be customized to help the LLM to recover except ToolInputValidationError as e: diff --git a/python/beeai_framework/tools/mcp_tools.py b/python/beeai_framework/tools/mcp_tools.py index eb0f198f0..aefa41ae0 100644 --- a/python/beeai_framework/tools/mcp_tools.py +++ b/python/beeai_framework/tools/mcp_tools.py @@ -57,11 +57,12 @@ class MCPTool(Tool[MCPToolOutput]): def __init__(self, client: ClientSession, tool: MCPToolInfo, **options: int) -> None: """Initialize MCPTool with client and tool configuration.""" - super().__init__(options) self.client = client self._tool = tool self._name = tool.name self._description = tool.description or "No available description, use the tool based on its name and schema." + + super().__init__(options) self.emitter = Emitter.root().child( EmitterInput( namespace=["tool", "mcp", self._name], diff --git a/python/beeai_framework/tools/search/duckduckgo.py b/python/beeai_framework/tools/search/duckduckgo.py index 307c260e4..61b42f2f0 100644 --- a/python/beeai_framework/tools/search/duckduckgo.py +++ b/python/beeai_framework/tools/search/duckduckgo.py @@ -18,6 +18,8 @@ from duckduckgo_search import DDGS from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolError from beeai_framework.tools.search import SearchToolOutput, SearchToolResult from beeai_framework.tools.tool import Tool @@ -53,6 +55,12 @@ def __init__(self, max_results: int = 10, safe_search: str = DuckDuckGoSearchTyp super().__init__() self.max_results = max_results self.safe_search = safe_search + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "search", "duckduckgo"], + creator=self, + ) + ) def _run(self, input: DuckDuckGoSearchToolInput, _: Any | None = None) -> DuckDuckGoSearchToolOutput: try: diff --git a/python/beeai_framework/tools/search/wikipedia.py b/python/beeai_framework/tools/search/wikipedia.py index bed111ad0..690e42e26 100644 --- a/python/beeai_framework/tools/search/wikipedia.py +++ b/python/beeai_framework/tools/search/wikipedia.py @@ -18,6 +18,8 @@ import wikipediaapi from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools.search import SearchToolOutput, SearchToolResult from beeai_framework.tools.tool import Tool @@ -47,6 +49,15 @@ class WikipediaTool(Tool[WikipediaToolInput]): user_agent="beeai-framework https://github.com/i-am-bee/beeai-framework", language="en" ) + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "search", "wikipedia"], + creator=self, + ) + ) + def get_section_titles(self, sections: wikipediaapi.WikipediaPage.sections) -> str: titles = [] for section in sections: diff --git a/python/beeai_framework/tools/tool.py b/python/beeai_framework/tools/tool.py index 14ed9f0ee..635bb9f39 100644 --- a/python/beeai_framework/tools/tool.py +++ b/python/beeai_framework/tools/tool.py @@ -14,13 +14,19 @@ import inspect +import re from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any, Generic, TypeVar from pydantic import BaseModel, ConfigDict, ValidationError, create_model -from beeai_framework.tools.errors import ToolInputValidationError +from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput +from beeai_framework.errors import FrameworkError +from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput +from beeai_framework.tools.errors import ToolError, ToolInputValidationError from beeai_framework.utils import BeeLogger logger = BeeLogger(__name__) @@ -56,6 +62,8 @@ def get_text_content(self) -> str: class Tool(Generic[T], ABC): options: dict[str, Any] + emitter: Emitter + def __init__(self, options: dict[str, Any] | None = None) -> None: if options is None: options = {} @@ -93,8 +101,59 @@ def prompt_data(self) -> dict[str, str]: "input_schema": str(self.input_schema.model_json_schema(mode="serialization")), } - def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Any: - return self._run(self.validate_input(input), options) + def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Run[Any]: + async def run_tool(context: RunContext) -> Any: + error_propagated = False + + try: + validated_input = self.validate_input(input) + + meta = {"input": validated_input, "options": options} + + async def executor(_: RetryableContext) -> Any: + nonlocal error_propagated + error_propagated = False + await context.emitter.emit("start", meta) + return self._run(validated_input, options) + + async def on_error(error: Exception, _: RetryableContext) -> None: + nonlocal error_propagated + error_propagated = True + err = FrameworkError.ensure(error) + await context.emitter.emit("error", {"error": err, **meta}) + if err.is_fatal: + raise err from None + + async def on_retry(ctx: RetryableContext, last_error: Exception) -> None: + err = ToolError.ensure(last_error) + await context.emitter.emit("retry", {"error": err, **meta}) + + output = await Retryable( + RetryableInput( + executor=executor, + on_error=on_error, + on_retry=on_retry, + config=RetryableConfig( + max_retries=options.get("max_retries") if options else 1, signal=context.signal + ), + ) + ).get() + + await context.emitter.emit("success", {"output": output, **meta}) + return output + except Exception as e: + err = ToolError.ensure(e) + if not error_propagated: + await context.emitter.emit("error", {"error": err, "input": input, "options": options}) + raise err from None + finally: + await context.emitter.emit("finish", None) + + return RunContext.enter( + RunInstance(emitter=self.emitter), + RunContextInput(params=[input, options], signal=options.signal if options else None), + run_tool, + ) # this method was inspired by the discussion that was had in this issue: @@ -137,6 +196,17 @@ class FunctionTool(Tool): description = tool_description input_schema = tool_input + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + # replace any non-alphanumeric char with _ + formatted_name = re.sub(r"\W+", "_", self.name).lower() + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "custom", formatted_name], + creator=self, + ) + ) + def _run(self, tool_in: Any, _: dict[str, Any] | None = None) -> None: tool_input_dict = tool_in.model_dump() return tool_function(**tool_input_dict) diff --git a/python/beeai_framework/tools/weather/openmeteo.py b/python/beeai_framework/tools/weather/openmeteo.py index 36ad3b10c..4f51fbf98 100644 --- a/python/beeai_framework/tools/weather/openmeteo.py +++ b/python/beeai_framework/tools/weather/openmeteo.py @@ -22,6 +22,8 @@ import requests from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolInputValidationError from beeai_framework.tools.tool import StringToolOutput, Tool from beeai_framework.utils import BeeLogger @@ -48,6 +50,15 @@ class OpenMeteoTool(Tool[OpenMeteoToolInput]): description = "Retrieve current, past, or future weather forecasts for a location." input_schema = OpenMeteoToolInput + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "weather", "openmeteo"], + creator=self, + ) + ) + def _geocode(self, input: OpenMeteoToolInput) -> dict[str, Any]: params = {"format": "json", "count": 1} if input.location_name: diff --git a/python/docs/tools.md b/python/docs/tools.md index 84e4260b0..2baf17c20 100644 --- a/python/docs/tools.md +++ b/python/docs/tools.md @@ -34,7 +34,7 @@ from beeai_framework.tools.weather.openmeteo import OpenMeteoTool, OpenMeteoTool async def main() -> None: tool = OpenMeteoTool() - result = tool.run( + result = await tool.run( input=OpenMeteoToolInput(location_name="New York", start_date="2025-01-01", end_date="2025-01-02") ) print(result.get_text_content()) @@ -58,7 +58,7 @@ from beeai_framework.tools.weather.openmeteo import OpenMeteoTool, OpenMeteoTool async def main() -> None: tool = OpenMeteoTool() - result = tool.run( + result = await tool.run( input=OpenMeteoToolInput( location_name="New York", start_date="2025-01-01", end_date="2025-01-02", temperature_unit="celsius" ) @@ -235,7 +235,7 @@ from beeai_framework.tools.search.wikipedia import ( async def main() -> None: wikipedia_client = WikipediaTool({"full_text": True}) input = WikipediaToolInput(query="bee") - result = wikipedia_client.run(input) + result = await wikipedia_client.run(input) print(result.get_text_content()) @@ -262,6 +262,8 @@ from typing import Any from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools.tool import Tool @@ -284,6 +286,15 @@ class RiddleTool(Tool[RiddleToolInput]): "What goes up but never comes down?", ) + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "example", "riddle"], + creator=self, + ) + ) + def _run(self, input: RiddleToolInput, _: Any | None = None) -> None: index = input.riddle_number % (len(self.data)) riddle = self.data[index] @@ -293,7 +304,7 @@ class RiddleTool(Tool[RiddleToolInput]): async def main() -> None: tool = RiddleTool() input = RiddleToolInput(riddle_number=random.randint(0, len(RiddleTool.data))) - result = tool.run(input) + result = await tool.run(input) print(result) @@ -324,6 +335,8 @@ from typing import Any import requests from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolInputValidationError from beeai_framework.tools.tool import Tool @@ -346,6 +359,15 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]): authors, contributors, publication dates, publisher and isbn.""" input_schema = OpenLibraryToolInput + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "example", "openlibrary"], + creator=self, + ) + ) + def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibraryToolResult: key = "" value = "" @@ -375,7 +397,7 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]): async def main() -> None: tool = OpenLibraryTool() input = OpenLibraryToolInput(title="It") - result = tool.run(input) + result = await tool.run(input) print(result) diff --git a/python/examples/tools/advanced.py b/python/examples/tools/advanced.py index d2bdf7271..b0f1f241c 100644 --- a/python/examples/tools/advanced.py +++ b/python/examples/tools/advanced.py @@ -5,7 +5,7 @@ async def main() -> None: tool = OpenMeteoTool() - result = tool.run( + result = await tool.run( input=OpenMeteoToolInput( location_name="New York", start_date="2025-01-01", end_date="2025-01-02", temperature_unit="celsius" ) diff --git a/python/examples/tools/base.py b/python/examples/tools/base.py index 1ca5145be..945b98c24 100644 --- a/python/examples/tools/base.py +++ b/python/examples/tools/base.py @@ -5,7 +5,7 @@ async def main() -> None: tool = OpenMeteoTool() - result = tool.run( + result = await tool.run( input=OpenMeteoToolInput(location_name="New York", start_date="2025-01-01", end_date="2025-01-02") ) print(result.get_text_content()) diff --git a/python/examples/tools/custom/base.py b/python/examples/tools/custom/base.py index 49a4f02a6..79b3eaf8b 100644 --- a/python/examples/tools/custom/base.py +++ b/python/examples/tools/custom/base.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools.tool import Tool @@ -26,6 +28,15 @@ class RiddleTool(Tool[RiddleToolInput]): "What goes up but never comes down?", ) + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "example", "riddle"], + creator=self, + ) + ) + def _run(self, input: RiddleToolInput, _: Any | None = None) -> None: index = input.riddle_number % (len(self.data)) riddle = self.data[index] @@ -35,7 +46,7 @@ def _run(self, input: RiddleToolInput, _: Any | None = None) -> None: async def main() -> None: tool = RiddleTool() input = RiddleToolInput(riddle_number=random.randint(0, len(RiddleTool.data))) - result = tool.run(input) + result = await tool.run(input) print(result) diff --git a/python/examples/tools/custom/openlibrary.py b/python/examples/tools/custom/openlibrary.py index f270c675d..790aa4ddd 100644 --- a/python/examples/tools/custom/openlibrary.py +++ b/python/examples/tools/custom/openlibrary.py @@ -4,6 +4,8 @@ import requests from pydantic import BaseModel, Field +from beeai_framework.emitter.emitter import Emitter +from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolInputValidationError from beeai_framework.tools.tool import Tool @@ -26,6 +28,15 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]): authors, contributors, publication dates, publisher and isbn.""" input_schema = OpenLibraryToolInput + def __init__(self, options: dict[str, Any] | None = None) -> None: + super().__init__(options) + self.emitter = Emitter.root().child( + EmitterInput( + namespace=["tool", "example", "openlibrary"], + creator=self, + ) + ) + def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibraryToolResult: key = "" value = "" @@ -55,7 +66,7 @@ def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibrary async def main() -> None: tool = OpenLibraryTool() input = OpenLibraryToolInput(title="It") - result = tool.run(input) + result = await tool.run(input) print(result) diff --git a/python/examples/tools/wikipedia.py b/python/examples/tools/wikipedia.py index 0ee1e849d..46f1685cc 100644 --- a/python/examples/tools/wikipedia.py +++ b/python/examples/tools/wikipedia.py @@ -9,7 +9,7 @@ async def main() -> None: wikipedia_client = WikipediaTool({"full_text": True}) input = WikipediaToolInput(query="bee") - result = wikipedia_client.run(input) + result = await wikipedia_client.run(input) print(result.get_text_content()) diff --git a/python/tests/tools/test_duckduckgo.py b/python/tests/tools/test_duckduckgo.py index e6dae79f6..4a8db8851 100644 --- a/python/tests/tools/test_duckduckgo.py +++ b/python/tests/tools/test_duckduckgo.py @@ -38,9 +38,10 @@ def tool() -> DuckDuckGoSearchTool: @pytest.mark.unit -def test_call_invalid_input_type(tool: DuckDuckGoSearchTool) -> None: +@pytest.mark.asyncio +async def test_call_invalid_input_type(tool: DuckDuckGoSearchTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input={"search": "Poland"}) + await tool.run(input={"search": "Poland"}) """ @@ -49,7 +50,8 @@ def test_call_invalid_input_type(tool: DuckDuckGoSearchTool) -> None: @pytest.mark.e2e -def test_output(tool: DuckDuckGoSearchTool) -> None: - result = tool.run(input=DuckDuckGoSearchToolInput(query="What is the area of the Poland?")) +@pytest.mark.asyncio +async def test_output(tool: DuckDuckGoSearchTool) -> None: + result = await tool.run(input=DuckDuckGoSearchToolInput(query="What is the area of the Poland?")) assert type(result) is DuckDuckGoSearchToolOutput assert "322,575" in result.get_text_content() diff --git a/python/tests/tools/test_opemmeteo.py b/python/tests/tools/test_opemmeteo.py index 017c5d750..cea76750c 100644 --- a/python/tests/tools/test_opemmeteo.py +++ b/python/tests/tools/test_opemmeteo.py @@ -35,8 +35,9 @@ def tool() -> OpenMeteoTool: @pytest.mark.e2e -def test_call_model(tool: OpenMeteoTool) -> None: - tool.run( +@pytest.mark.asyncio +async def test_call_model(tool: OpenMeteoTool) -> None: + await tool.run( input=OpenMeteoToolInput( location_name="Cambridge", country="US", @@ -46,42 +47,55 @@ def test_call_model(tool: OpenMeteoTool) -> None: @pytest.mark.e2e -def test_call_dict(tool: OpenMeteoTool) -> None: - tool.run(input={"location_name": "White Plains"}) +@pytest.mark.asyncio +async def test_call_dict(tool: OpenMeteoTool) -> None: + await tool.run(input={"location_name": "White Plains"}) @pytest.mark.e2e -def test_call_invalid_missing_field(tool: OpenMeteoTool) -> None: +@pytest.mark.asyncio +async def test_call_invalid_missing_field(tool: OpenMeteoTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input={}) + await tool.run(input={}) @pytest.mark.e2e -def test_call_invalid_bad_type(tool: OpenMeteoTool) -> None: +@pytest.mark.asyncio +async def test_call_invalid_bad_type(tool: OpenMeteoTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input={"location_name": 1}) + await tool.run(input={"location_name": 1}) @pytest.mark.e2e -def test_output(tool: OpenMeteoTool) -> None: - result = tool.run(input={"location_name": "White Plains"}) +@pytest.mark.asyncio +async def test_output(tool: OpenMeteoTool) -> None: + result = await tool.run(input={"location_name": "White Plains"}) assert type(result) is StringToolOutput assert "current" in result.get_text_content() @pytest.mark.e2e -def test_bad_start_date_format(tool: OpenMeteoTool) -> None: +@pytest.mark.asyncio +async def test_bad_start_date_format(tool: OpenMeteoTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input=OpenMeteoToolInput(location_name="White Plains", start_date="2025:01:01", end_date="2025:01:02")) + await tool.run( + input=OpenMeteoToolInput(location_name="White Plains", start_date="2025:01:01", end_date="2025-01-02") + ) @pytest.mark.e2e -def test_bad_end_date_format(tool: OpenMeteoTool) -> None: +@pytest.mark.asyncio +async def test_bad_end_date_format(tool: OpenMeteoTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input=OpenMeteoToolInput(location_name="White Plains", start_date="2025-01-01", end_date="2025:01:02")) + await tool.run( + input=OpenMeteoToolInput(location_name="White Plains", start_date="2025-01-01", end_date="2025:01:02") + ) @pytest.mark.e2e -def test_bad_dates(tool: OpenMeteoTool) -> None: +@pytest.mark.asyncio +async def test_bad_dates(tool: OpenMeteoTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input=OpenMeteoToolInput(location_name="White Plains", start_date="2025-02-02", end_date="2025-02-01")) + await tool.run( + input=OpenMeteoToolInput(location_name="White Plains", start_date="2025-02-02", end_date="2025-02-01") + ) diff --git a/python/tests/tools/test_wikipedia.py b/python/tests/tools/test_wikipedia.py index 0dc23c1f8..0823373ec 100644 --- a/python/tests/tools/test_wikipedia.py +++ b/python/tests/tools/test_wikipedia.py @@ -38,34 +38,39 @@ def tool() -> WikipediaTool: @pytest.mark.e2e -def test_call_invalid_input_type(tool: WikipediaTool) -> None: +@pytest.mark.asyncio +async def test_call_invalid_input_type(tool: WikipediaTool) -> None: with pytest.raises(ToolInputValidationError): - tool.run(input={"search": "Bee"}) + await tool.run(input={"search": "Bee"}) @pytest.mark.e2e -def test_output(tool: WikipediaTool) -> None: - result = tool.run(input=WikipediaToolInput(query="bee")) +@pytest.mark.asyncio +async def test_output(tool: WikipediaTool) -> None: + result = await tool.run(input=WikipediaToolInput(query="bee")) assert type(result) is WikipediaToolOutput assert "Bees are winged insects closely related to wasps and ants" in result.get_text_content() @pytest.mark.e2e -def test_full_text_output(tool: WikipediaTool) -> None: - result = tool.run(input=WikipediaToolInput(query="bee", full_text=True)) +@pytest.mark.asyncio +async def test_full_text_output(tool: WikipediaTool) -> None: + result = await tool.run(input=WikipediaToolInput(query="bee", full_text=True)) assert type(result) is WikipediaToolOutput assert "n-triscosane" in result.get_text_content() @pytest.mark.e2e -def test_section_titles(tool: WikipediaTool) -> None: - result = tool.run(input=WikipediaToolInput(query="bee", section_titles=True)) +@pytest.mark.asyncio +async def test_section_titles(tool: WikipediaTool) -> None: + result = await tool.run(input=WikipediaToolInput(query="bee", section_titles=True)) assert type(result) is WikipediaToolOutput assert "Characteristics" in result.get_text_content() @pytest.mark.e2e -def test_alternate_language(tool: WikipediaTool) -> None: - result = tool.run(input=WikipediaToolInput(query="bee", language="fr")) +@pytest.mark.asyncio +async def test_alternate_language(tool: WikipediaTool) -> None: + result = await tool.run(input=WikipediaToolInput(query="bee", language="fr")) assert type(result) is WikipediaToolOutput assert "Les abeilles (Anthophila) forment un clade d'insectes" in result.get_text_content() From cf87caed3509b88a3c17b3c78b54db51359ad716 Mon Sep 17 00:00:00 2001 From: va Date: Fri, 28 Feb 2025 11:23:20 -0500 Subject: [PATCH 2/5] chore: update from merge Signed-off-by: va --- python/beeai_framework/tools/search/duckduckgo.py | 7 ++----- python/beeai_framework/tools/search/wikipedia.py | 7 ++----- python/beeai_framework/tools/tool.py | 7 ++----- python/beeai_framework/tools/weather/openmeteo.py | 7 ++----- python/examples/tools/custom/base.py | 7 ++----- python/examples/tools/custom/openlibrary.py | 7 ++----- 6 files changed, 12 insertions(+), 30 deletions(-) diff --git a/python/beeai_framework/tools/search/duckduckgo.py b/python/beeai_framework/tools/search/duckduckgo.py index 61b42f2f0..0d940498b 100644 --- a/python/beeai_framework/tools/search/duckduckgo.py +++ b/python/beeai_framework/tools/search/duckduckgo.py @@ -19,7 +19,6 @@ from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolError from beeai_framework.tools.search import SearchToolOutput, SearchToolResult from beeai_framework.tools.tool import Tool @@ -56,10 +55,8 @@ def __init__(self, max_results: int = 10, safe_search: str = DuckDuckGoSearchTyp self.max_results = max_results self.safe_search = safe_search self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "search", "duckduckgo"], - creator=self, - ) + namespace=["tool", "search", "duckduckgo"], + creator=self, ) def _run(self, input: DuckDuckGoSearchToolInput, _: Any | None = None) -> DuckDuckGoSearchToolOutput: diff --git a/python/beeai_framework/tools/search/wikipedia.py b/python/beeai_framework/tools/search/wikipedia.py index 690e42e26..8f82412ce 100644 --- a/python/beeai_framework/tools/search/wikipedia.py +++ b/python/beeai_framework/tools/search/wikipedia.py @@ -19,7 +19,6 @@ from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools.search import SearchToolOutput, SearchToolResult from beeai_framework.tools.tool import Tool @@ -52,10 +51,8 @@ class WikipediaTool(Tool[WikipediaToolInput]): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "search", "wikipedia"], - creator=self, - ) + namespace=["tool", "search", "wikipedia"], + creator=self, ) def get_section_titles(self, sections: wikipediaapi.WikipediaPage.sections) -> str: diff --git a/python/beeai_framework/tools/tool.py b/python/beeai_framework/tools/tool.py index 635bb9f39..3b37b6f2c 100644 --- a/python/beeai_framework/tools/tool.py +++ b/python/beeai_framework/tools/tool.py @@ -23,7 +23,6 @@ from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.errors import FrameworkError from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput from beeai_framework.tools.errors import ToolError, ToolInputValidationError @@ -201,10 +200,8 @@ def __init__(self, options: dict[str, Any] | None = None) -> None: # replace any non-alphanumeric char with _ formatted_name = re.sub(r"\W+", "_", self.name).lower() self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "custom", formatted_name], - creator=self, - ) + namespace=["tool", "custom", formatted_name], + creator=self, ) def _run(self, tool_in: Any, _: dict[str, Any] | None = None) -> None: diff --git a/python/beeai_framework/tools/weather/openmeteo.py b/python/beeai_framework/tools/weather/openmeteo.py index 4f51fbf98..075375fba 100644 --- a/python/beeai_framework/tools/weather/openmeteo.py +++ b/python/beeai_framework/tools/weather/openmeteo.py @@ -23,7 +23,6 @@ from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolInputValidationError from beeai_framework.tools.tool import StringToolOutput, Tool from beeai_framework.utils import BeeLogger @@ -53,10 +52,8 @@ class OpenMeteoTool(Tool[OpenMeteoToolInput]): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "weather", "openmeteo"], - creator=self, - ) + namespace=["tool", "weather", "openmeteo"], + creator=self, ) def _geocode(self, input: OpenMeteoToolInput) -> dict[str, Any]: diff --git a/python/examples/tools/custom/base.py b/python/examples/tools/custom/base.py index 79b3eaf8b..66107728d 100644 --- a/python/examples/tools/custom/base.py +++ b/python/examples/tools/custom/base.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools.tool import Tool @@ -31,10 +30,8 @@ class RiddleTool(Tool[RiddleToolInput]): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "example", "riddle"], - creator=self, - ) + namespace=["tool", "example", "riddle"], + creator=self, ) def _run(self, input: RiddleToolInput, _: Any | None = None) -> None: diff --git a/python/examples/tools/custom/openlibrary.py b/python/examples/tools/custom/openlibrary.py index 790aa4ddd..9e3968436 100644 --- a/python/examples/tools/custom/openlibrary.py +++ b/python/examples/tools/custom/openlibrary.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolInputValidationError from beeai_framework.tools.tool import Tool @@ -31,10 +30,8 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "example", "openlibrary"], - creator=self, - ) + namespace=["tool", "example", "openlibrary"], + creator=self, ) def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibraryToolResult: From b4acb79ab4d872d97f64098a69fc03db9f8780c8 Mon Sep 17 00:00:00 2001 From: va Date: Fri, 28 Feb 2025 11:25:39 -0500 Subject: [PATCH 3/5] fix: feedback review Signed-off-by: va --- python/beeai_framework/agents/runners/default/runner.py | 3 --- python/beeai_framework/tools/mcp_tools.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/beeai_framework/agents/runners/default/runner.py b/python/beeai_framework/agents/runners/default/runner.py index 67e4edd63..7c214a40a 100644 --- a/python/beeai_framework/agents/runners/default/runner.py +++ b/python/beeai_framework/agents/runners/default/runner.py @@ -236,9 +236,6 @@ async def on_error(error: Exception, _: RetryableContext) -> None: async def executor(_: RetryableContext) -> BeeRunnerToolResult: try: - # tool_options = copy.copy(self._options) - # TODO Tool run is not async - # Convert tool input to dict tool_output: ToolOutput = await tool.run(input.state.tool_input, options={}) # TODO: pass tool options return BeeRunnerToolResult(output=tool_output, success=True) # TODO These error templates should be customized to help the LLM to recover diff --git a/python/beeai_framework/tools/mcp_tools.py b/python/beeai_framework/tools/mcp_tools.py index 12520d667..6c718f9f2 100644 --- a/python/beeai_framework/tools/mcp_tools.py +++ b/python/beeai_framework/tools/mcp_tools.py @@ -57,12 +57,12 @@ class MCPTool(Tool[MCPToolOutput]): def __init__(self, client: ClientSession, tool: MCPToolInfo, **options: int) -> None: """Initialize MCPTool with client and tool configuration.""" + super().__init__(options) self.client = client self._tool = tool self._name = tool.name self._description = tool.description or "No available description, use the tool based on its name and schema." - super().__init__(options) self.emitter = Emitter.root().child( namespace=["tool", "mcp", self._name], creator=self, From eb6ba1efde9256c4d4372184596c5e05af9b870f Mon Sep 17 00:00:00 2001 From: va Date: Fri, 28 Feb 2025 11:27:17 -0500 Subject: [PATCH 4/5] docs: update tools doc Signed-off-by: va --- python/docs/tools.md | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/docs/tools.md b/python/docs/tools.md index d7a7231ec..f40165323 100644 --- a/python/docs/tools.md +++ b/python/docs/tools.md @@ -259,7 +259,6 @@ from typing import Any from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools.tool import Tool @@ -285,10 +284,8 @@ class RiddleTool(Tool[RiddleToolInput]): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "example", "riddle"], - creator=self, - ) + namespace=["tool", "example", "riddle"], + creator=self, ) def _run(self, input: RiddleToolInput, _: Any | None = None) -> None: @@ -332,7 +329,6 @@ import requests from pydantic import BaseModel, Field from beeai_framework.emitter.emitter import Emitter -from beeai_framework.emitter.types import EmitterInput from beeai_framework.tools import ToolInputValidationError from beeai_framework.tools.tool import Tool @@ -358,10 +354,8 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) self.emitter = Emitter.root().child( - EmitterInput( - namespace=["tool", "example", "openlibrary"], - creator=self, - ) + namespace=["tool", "example", "openlibrary"], + creator=self, ) def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibraryToolResult: From 53b029bcc960660a29eb76c924a51e7356b6b644 Mon Sep 17 00:00:00 2001 From: va Date: Fri, 28 Feb 2025 12:32:01 -0500 Subject: [PATCH 5/5] fix: feedback review Signed-off-by: va --- python/beeai_framework/tools/tool.py | 6 ++---- python/beeai_framework/utils/strings.py | 5 +++++ python/beeai_framework/workflows/workflow.py | 6 ++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/beeai_framework/tools/tool.py b/python/beeai_framework/tools/tool.py index 3b37b6f2c..0102bcc57 100644 --- a/python/beeai_framework/tools/tool.py +++ b/python/beeai_framework/tools/tool.py @@ -14,7 +14,6 @@ import inspect -import re from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any, Generic, TypeVar @@ -27,6 +26,7 @@ from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput from beeai_framework.tools.errors import ToolError, ToolInputValidationError from beeai_framework.utils import BeeLogger +from beeai_framework.utils.strings import to_safe_word logger = BeeLogger(__name__) @@ -197,10 +197,8 @@ class FunctionTool(Tool): def __init__(self, options: dict[str, Any] | None = None) -> None: super().__init__(options) - # replace any non-alphanumeric char with _ - formatted_name = re.sub(r"\W+", "_", self.name).lower() self.emitter = Emitter.root().child( - namespace=["tool", "custom", formatted_name], + namespace=["tool", "custom", to_safe_word(self.name)], creator=self, ) diff --git a/python/beeai_framework/utils/strings.py b/python/beeai_framework/utils/strings.py index fd039932f..b3581bc4e 100644 --- a/python/beeai_framework/utils/strings.py +++ b/python/beeai_framework/utils/strings.py @@ -46,3 +46,8 @@ def create_strenum(name: str, keys: Sequence[str]) -> type[StrEnum]: def to_json(input: Any, *, indent: int | None = None) -> str: return json.dumps(input, ensure_ascii=False, default=lambda o: o.__dict__, sort_keys=True, indent=indent) + + +def to_safe_word(phrase: str) -> str: + # replace any non-alphanumeric char with _ + return re.sub(r"\W+", "_", phrase).lower() diff --git a/python/beeai_framework/workflows/workflow.py b/python/beeai_framework/workflows/workflow.py index f1510564f..af3de91d8 100644 --- a/python/beeai_framework/workflows/workflow.py +++ b/python/beeai_framework/workflows/workflow.py @@ -14,7 +14,6 @@ import asyncio import inspect -import re from collections.abc import Awaitable, Callable from dataclasses import field from typing import Any, ClassVar, Final, Generic, Literal @@ -27,6 +26,7 @@ from beeai_framework.emitter.emitter import Emitter from beeai_framework.errors import FrameworkError from beeai_framework.utils.models import ModelLike, check_model, to_model, to_model_optional +from beeai_framework.utils.strings import to_safe_word from beeai_framework.utils.types import MaybeAsync from beeai_framework.workflows.errors import WorkflowError @@ -84,10 +84,8 @@ def __init__(self, schema: type[T], name: str = "Workflow") -> None: self._steps: dict[K, WorkflowStepDefinition[T, K]] = {} self._start_step: K | None = None - # replace any non-alphanumeric char with _ - formatted_name = re.sub(r"\W+", "_", self._name).lower() self.emitter = Emitter.root().child( - namespace=["workflow", formatted_name], + namespace=["workflow", to_safe_word(self._name)], creator=self, )