feat: add runcontext + retryable + emitter to tool #429

Merged on Feb 28, 2025 · 7 commits
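This PR makes `Tool.run` awaitable and observable: each run is wrapped in a `RunContext`, executed through a `Retryable`, and reported on a per-tool `Emitter` with `start`, `retry`, `error`, `success`, and `finish` events. The built-in tools (DuckDuckGo, Wikipedia, OpenMeteo, and `FunctionTool`) each create a child emitter under a tool-specific namespace, and call sites change from `result = tool.run(...)` to `result = await tool.run(...)`. A minimal usage sketch, adapted from the updated `docs/tools.md` below (it assumes the OpenMeteo tool's dependencies are installed):

```python
import asyncio

from beeai_framework.tools.weather.openmeteo import OpenMeteoTool, OpenMeteoToolInput


async def main() -> None:
    tool = OpenMeteoTool()
    # Tool.run now returns an awaitable Run: it validates the input, executes _run
    # inside a Retryable, and emits lifecycle events on the tool's emitter.
    result = await tool.run(
        input=OpenMeteoToolInput(location_name="New York", start_date="2025-01-01", end_date="2025-01-02")
    )
    print(result.get_text_content())


if __name__ == "__main__":
    asyncio.run(main())
```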
5 changes: 1 addition & 4 deletions python/beeai_framework/agents/runners/default/runner.py
@@ -236,10 +236,7 @@ async def on_error(error: Exception, _: RetryableContext) -> None:

async def executor(_: RetryableContext) -> BeeRunnerToolResult:
try:
# tool_options = copy.copy(self._options)
# TODO Tool run is not async
# Convert tool input to dict
tool_output: ToolOutput = tool.run(input.state.tool_input, options={}) # TODO: pass tool options
tool_output: ToolOutput = await tool.run(input.state.tool_input, options={}) # TODO: pass tool options
return BeeRunnerToolResult(output=tool_output, success=True)
# TODO These error templates should be customized to help the LLM to recover
except ToolInputValidationError as e:
1 change: 1 addition & 0 deletions python/beeai_framework/tools/mcp_tools.py
@@ -62,6 +62,7 @@ def __init__(self, client: ClientSession, tool: MCPToolInfo, **options: int) ->
self._tool = tool
self._name = tool.name
self._description = tool.description or "No available description, use the tool based on its name and schema."

self.emitter = Emitter.root().child(
namespace=["tool", "mcp", self._name],
creator=self,
5 changes: 5 additions & 0 deletions python/beeai_framework/tools/search/duckduckgo.py
@@ -18,6 +18,7 @@
from duckduckgo_search import DDGS
from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools import ToolError
from beeai_framework.tools.search import SearchToolOutput, SearchToolResult
from beeai_framework.tools.tool import Tool
@@ -53,6 +54,10 @@ def __init__(self, max_results: int = 10, safe_search: str = DuckDuckGoSearchTyp
super().__init__()
self.max_results = max_results
self.safe_search = safe_search
self.emitter = Emitter.root().child(
namespace=["tool", "search", "duckduckgo"],
creator=self,
)

def _run(self, input: DuckDuckGoSearchToolInput, _: Any | None = None) -> DuckDuckGoSearchToolOutput:
try:
8 changes: 8 additions & 0 deletions python/beeai_framework/tools/search/wikipedia.py
@@ -18,6 +18,7 @@
import wikipediaapi
from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools.search import SearchToolOutput, SearchToolResult
from beeai_framework.tools.tool import Tool

@@ -47,6 +48,13 @@ class WikipediaTool(Tool[WikipediaToolInput]):
user_agent="beeai-framework https://github.com/i-am-bee/beeai-framework", language="en"
)

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "search", "wikipedia"],
creator=self,
)

def get_section_titles(self, sections: wikipediaapi.WikipediaPage.sections) -> str:
titles = []
for section in sections:
71 changes: 68 additions & 3 deletions python/beeai_framework/tools/tool.py
@@ -20,8 +20,13 @@

from pydantic import BaseModel, ConfigDict, ValidationError, create_model

from beeai_framework.tools.errors import ToolInputValidationError
from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance
from beeai_framework.emitter.emitter import Emitter
from beeai_framework.errors import FrameworkError
from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
from beeai_framework.tools.errors import ToolError, ToolInputValidationError
from beeai_framework.utils import BeeLogger
from beeai_framework.utils.strings import to_safe_word

logger = BeeLogger(__name__)

@@ -56,6 +61,8 @@ def get_text_content(self) -> str:
class Tool(Generic[T], ABC):
options: dict[str, Any]

emitter: Emitter

def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
@@ -93,8 +100,59 @@ def prompt_data(self) -> dict[str, str]:
"input_schema": str(self.input_schema.model_json_schema(mode="serialization")),
}

def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Any:
return self._run(self.validate_input(input), options)
def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Run[Any]:
async def run_tool(context: RunContext) -> Any:
error_propagated = False

try:
validated_input = self.validate_input(input)

meta = {"input": validated_input, "options": options}

async def executor(_: RetryableContext) -> Any:
nonlocal error_propagated
error_propagated = False
await context.emitter.emit("start", meta)
return self._run(validated_input, options)

async def on_error(error: Exception, _: RetryableContext) -> None:
nonlocal error_propagated
error_propagated = True
err = FrameworkError.ensure(error)
await context.emitter.emit("error", {"error": err, **meta})
if err.is_fatal:
raise err from None

async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:
err = ToolError.ensure(last_error)
await context.emitter.emit("retry", {"error": err, **meta})

output = await Retryable(
RetryableInput(
executor=executor,
on_error=on_error,
on_retry=on_retry,
config=RetryableConfig(
max_retries=options.get("max_retries") if options else 1, signal=context.signal
),
)
).get()

await context.emitter.emit("success", {"output": output, **meta})
return output
except Exception as e:
err = ToolError.ensure(e)
if not error_propagated:
await context.emitter.emit("error", {"error": err, "input": input, "options": options})
raise err from None
finally:
await context.emitter.emit("finish", None)

return RunContext.enter(
RunInstance(emitter=self.emitter),
RunContextInput(params=[input, options], signal=options.signal if options else None),
run_tool,
)


# this method was inspired by the discussion that was had in this issue:
Expand Down Expand Up @@ -137,6 +195,13 @@ class FunctionTool(Tool):
description = tool_description
input_schema = tool_input

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "custom", to_safe_word(self.name)],
creator=self,
)

def _run(self, tool_in: Any, _: dict[str, Any] | None = None) -> None:
tool_input_dict = tool_in.model_dump()
return tool_function(**tool_input_dict)
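Taken together, the new `run()` validates the input once, executes `_run` through a `Retryable` (with `max_retries` read from the options dict, defaulting to 1), emits `start`, `retry`, `error`, `success`, and `finish` on the run context's emitter, and re-raises any failure normalized via `ToolError.ensure`. A rough error-handling sketch under those assumptions (it catches the broader `FrameworkError`, since the exact exception type surfaced by the awaited `Run` is not pinned down in this diff):

```python
import asyncio

from beeai_framework.errors import FrameworkError
from beeai_framework.tools.weather.openmeteo import OpenMeteoTool, OpenMeteoToolInput


async def main() -> None:
    tool = OpenMeteoTool()
    try:
        result = await tool.run(
            input=OpenMeteoToolInput(location_name="New York", start_date="2025-01-01", end_date="2025-01-02")
        )
        print(result.get_text_content())
    except FrameworkError as err:
        # Failures inside the tool are wrapped with ToolError.ensure() and re-raised
        # after an "error" event has been emitted.
        print(f"tool failed: {err}")


if __name__ == "__main__":
    asyncio.run(main())
```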
8 changes: 8 additions & 0 deletions python/beeai_framework/tools/weather/openmeteo.py
@@ -22,6 +22,7 @@
import requests
from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools import ToolInputValidationError
from beeai_framework.tools.tool import StringToolOutput, Tool
from beeai_framework.utils import BeeLogger
@@ -48,6 +49,13 @@ class OpenMeteoTool(Tool[OpenMeteoToolInput]):
description = "Retrieve current, past, or future weather forecasts for a location."
input_schema = OpenMeteoToolInput

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "weather", "openmeteo"],
creator=self,
)

def _geocode(self, input: OpenMeteoToolInput) -> dict[str, Any]:
params = {"format": "json", "count": 1}
if input.location_name:
5 changes: 5 additions & 0 deletions python/beeai_framework/utils/strings.py
@@ -46,3 +46,8 @@ def create_strenum(name: str, keys: Sequence[str]) -> type[StrEnum]:

def to_json(input: Any, *, indent: int | None = None) -> str:
return json.dumps(input, ensure_ascii=False, default=lambda o: o.__dict__, sort_keys=True, indent=indent)


def to_safe_word(phrase: str) -> str:
# replace any non-alphanumeric char with _
return re.sub(r"\W+", "_", phrase).lower()
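
The new `to_safe_word` helper mirrors the namespace sanitization previously inlined in `workflow.py` (see the next file). A quick illustrative check of its behavior, with made-up inputs:

```python
from beeai_framework.utils.strings import to_safe_word

# Runs of non-alphanumeric characters collapse into a single underscore, and the result is lowercased.
assert to_safe_word("My Fancy Tool!") == "my_fancy_tool_"
assert to_safe_word("Workflow") == "workflow"
```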
6 changes: 2 additions & 4 deletions python/beeai_framework/workflows/workflow.py
@@ -14,7 +14,6 @@

import asyncio
import inspect
import re
from collections.abc import Awaitable, Callable
from dataclasses import field
from typing import Any, ClassVar, Final, Generic, Literal
@@ -27,6 +26,7 @@
from beeai_framework.emitter.emitter import Emitter
from beeai_framework.errors import FrameworkError
from beeai_framework.utils.models import ModelLike, check_model, to_model, to_model_optional
from beeai_framework.utils.strings import to_safe_word
from beeai_framework.utils.types import MaybeAsync
from beeai_framework.workflows.errors import WorkflowError

@@ -84,10 +84,8 @@ def __init__(self, schema: type[T], name: str = "Workflow") -> None:
self._steps: dict[K, WorkflowStepDefinition[T, K]] = {}
self._start_step: K | None = None

# replace any non-alphanumeric char with _
formatted_name = re.sub(r"\W+", "_", self._name).lower()
self.emitter = Emitter.root().child(
namespace=["workflow", formatted_name],
namespace=["workflow", to_safe_word(self._name)],
creator=self,
)

26 changes: 21 additions & 5 deletions python/docs/tools.md
@@ -34,7 +34,7 @@ from beeai_framework.tools.weather.openmeteo import OpenMeteoTool, OpenMeteoTool

async def main() -> None:
tool = OpenMeteoTool()
result = tool.run(
result = await tool.run(
input=OpenMeteoToolInput(location_name="New York", start_date="2025-01-01", end_date="2025-01-02")
)
print(result.get_text_content())
@@ -58,7 +58,7 @@ from beeai_framework.tools.weather.openmeteo import OpenMeteoTool, OpenMeteoTool

async def main() -> None:
tool = OpenMeteoTool()
result = tool.run(
result = await tool.run(
input=OpenMeteoToolInput(
location_name="New York", start_date="2025-01-01", end_date="2025-01-02", temperature_unit="celsius"
)
@@ -231,7 +231,7 @@ from beeai_framework.tools.search.wikipedia import (
async def main() -> None:
wikipedia_client = WikipediaTool({"full_text": True})
input = WikipediaToolInput(query="bee")
result = wikipedia_client.run(input)
result = await wikipedia_client.run(input)
print(result.get_text_content())


@@ -258,6 +258,7 @@ from typing import Any

from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools.tool import Tool


@@ -280,6 +281,13 @@ class RiddleTool(Tool[RiddleToolInput]):
"What goes up but never comes down?",
)

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "example", "riddle"],
creator=self,
)

def _run(self, input: RiddleToolInput, _: Any | None = None) -> None:
index = input.riddle_number % (len(self.data))
riddle = self.data[index]
@@ -289,7 +297,7 @@ class RiddleTool(Tool[RiddleToolInput]):
async def main() -> None:
tool = RiddleTool()
input = RiddleToolInput(riddle_number=random.randint(0, len(RiddleTool.data)))
result = tool.run(input)
result = await tool.run(input)
print(result)


@@ -320,6 +328,7 @@ from typing import Any
import requests
from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools import ToolInputValidationError
from beeai_framework.tools.tool import Tool

@@ -342,6 +351,13 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]):
authors, contributors, publication dates, publisher and isbn."""
input_schema = OpenLibraryToolInput

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "example", "openlibrary"],
creator=self,
)

def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibraryToolResult:
key = ""
value = ""
@@ -371,7 +387,7 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]):
async def main() -> None:
tool = OpenLibraryTool()
input = OpenLibraryToolInput(title="It")
result = tool.run(input)
result = await tool.run(input)
print(result)


2 changes: 1 addition & 1 deletion python/examples/tools/advanced.py
@@ -5,7 +5,7 @@

async def main() -> None:
tool = OpenMeteoTool()
result = tool.run(
result = await tool.run(
input=OpenMeteoToolInput(
location_name="New York", start_date="2025-01-01", end_date="2025-01-02", temperature_unit="celsius"
)
2 changes: 1 addition & 1 deletion python/examples/tools/base.py
@@ -5,7 +5,7 @@

async def main() -> None:
tool = OpenMeteoTool()
result = tool.run(
result = await tool.run(
input=OpenMeteoToolInput(location_name="New York", start_date="2025-01-01", end_date="2025-01-02")
)
print(result.get_text_content())
10 changes: 9 additions & 1 deletion python/examples/tools/custom/base.py
@@ -4,6 +4,7 @@

from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools.tool import Tool


@@ -26,6 +27,13 @@ class RiddleTool(Tool[RiddleToolInput]):
"What goes up but never comes down?",
)

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "example", "riddle"],
creator=self,
)

def _run(self, input: RiddleToolInput, _: Any | None = None) -> None:
index = input.riddle_number % (len(self.data))
riddle = self.data[index]
@@ -35,7 +43,7 @@ def _run(self, input: RiddleToolInput, _: Any | None = None) -> None:
async def main() -> None:
tool = RiddleTool()
input = RiddleToolInput(riddle_number=random.randint(0, len(RiddleTool.data)))
result = tool.run(input)
result = await tool.run(input)
print(result)


10 changes: 9 additions & 1 deletion python/examples/tools/custom/openlibrary.py
@@ -4,6 +4,7 @@
import requests
from pydantic import BaseModel, Field

from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools import ToolInputValidationError
from beeai_framework.tools.tool import Tool

@@ -26,6 +27,13 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput]):
authors, contributors, publication dates, publisher and isbn."""
input_schema = OpenLibraryToolInput

def __init__(self, options: dict[str, Any] | None = None) -> None:
super().__init__(options)
self.emitter = Emitter.root().child(
namespace=["tool", "example", "openlibrary"],
creator=self,
)

def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibraryToolResult:
key = ""
value = ""
@@ -55,7 +63,7 @@ def _run(self, input: OpenLibraryToolInput, _: Any | None = None) -> OpenLibrary
async def main() -> None:
tool = OpenLibraryTool()
input = OpenLibraryToolInput(title="It")
result = tool.run(input)
result = await tool.run(input)
print(result)


2 changes: 1 addition & 1 deletion python/examples/tools/wikipedia.py
@@ -9,7 +9,7 @@
async def main() -> None:
wikipedia_client = WikipediaTool({"full_text": True})
input = WikipediaToolInput(query="bee")
result = wikipedia_client.run(input)
result = await wikipedia_client.run(input)
print(result.get_text_content())

