fix(agents): handle native tool calling and retries
Ref: #441
Ref: #443
Ref: #428
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
Tomas2D committed Mar 3, 2025
1 parent 571af1d commit 62db3dc
Showing 15 changed files with 120 additions and 70 deletions.
16 changes: 15 additions & 1 deletion python/beeai_framework/adapters/litellm/chat.py
@@ -159,7 +159,21 @@ def _transform_input(self, input: ChatModelInput) -> LiteLLMParameters:
             else:
                 messages.append(message.to_plain())
 
-        tools = [{"type": "function", "function": tool.prompt_data()} for tool in input.tools] if input.tools else None
+        tools = (
+            [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool.name,
+                        "description": tool.description,
+                        "parameters": tool.input_schema.model_json_schema(mode="validation"),
+                    },
+                }
+                for tool in input.tools
+            ]
+            if input.tools
+            else None
+        )
 
         return LiteLLMParameters(
             model=f"{self._litellm_provider_id}/{self.model_id}",
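
For context, a minimal sketch of the OpenAI-style function spec the adapter now builds per tool, derived directly from the tool's name, description, and Pydantic input schema instead of the removed prompt_data() helper. The WeatherInput model and tool metadata below are hypothetical, purely for illustration:

from pydantic import BaseModel, Field


class WeatherInput(BaseModel):
    """Hypothetical input schema standing in for a real tool's input_schema."""

    city: str = Field(description="City to look up")


# Shape of a single entry in the `tools` list handed to LiteLLM.
spec = {
    "type": "function",
    "function": {
        "name": "weather",
        "description": "Returns the current weather for a city.",
        "parameters": WeatherInput.model_json_schema(mode="validation"),
    },
}
print(spec["function"]["parameters"]["required"])  # ['city']
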
9 changes: 5 additions & 4 deletions python/beeai_framework/agents/runners/base.py
@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
@@ -73,7 +72,9 @@ def __init__(self, input: BeeInput, options: BeeRunOptions, run: RunContext) -> None:
         self._failed_attempts_counter: RetryCounter = RetryCounter(
             error_type=AgentError,
             max_retries=(
-                options.execution.total_max_retries if options.execution and options.execution.total_max_retries else 0
+                options.execution.total_max_retries
+                if options.execution and options.execution.total_max_retries
+                else math.inf
             ),
         )
         self._run = run
@@ -93,7 +94,7 @@ async def create_iteration(self) -> RunnerIteration:
         max_iterations = (
             self._options.execution.max_iterations
             if self._options.execution and self._options.execution.max_iterations
-            else 0
+            else math.inf
         )
 
         if meta.iteration > max_iterations:
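
Net effect: with no execution config, the fallback changes from 0 to math.inf, so an absent limit now means "unbounded" rather than "limit of zero". A tiny illustrative sketch of the fallback, not framework code:

import math

# Hypothetical unset limits, as when options.execution is not provided.
total_max_retries = None
max_retries = total_max_retries if total_max_retries else math.inf

iteration = 12
max_iterations = math.inf  # same fallback as max_retries
print(iteration > max_iterations)  # False -> the iteration guard never fires by default
print(max_retries)                 # inf   -> the retry counter is effectively uncapped
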
37 changes: 36 additions & 1 deletion python/beeai_framework/agents/runners/default/prompts.py
@@ -17,6 +17,14 @@
 from beeai_framework.template import PromptTemplate, PromptTemplateInput
 
 
+class UserEmptyPromptTemplateInput(BaseModel):
+    pass
+
+
+class ToolNoResultsTemplateInput(BaseModel):
+    pass
+
+
 class UserPromptTemplateInput(BaseModel):
     input: str
 
@@ -49,6 +57,10 @@ class ToolInputErrorTemplateInput(BaseModel):
     reason: str
 
 
+class ToolErrorTemplateInput(BaseModel):
+    reason: str
+
+
 class SchemaErrorTemplateInput(BaseModel):
     pass
 
@@ -149,10 +161,33 @@ class SchemaErrorTemplateInput(BaseModel):
     )
 )
 
+ToolNoResultsTemplate = PromptTemplate(
+    PromptTemplateInput(
+        schema=ToolNoResultsTemplateInput,
+        template="""No results were found!""",
+    )
+)
+
+UserEmptyPromptTemplate = PromptTemplate(
+    PromptTemplateInput(
+        schema=UserEmptyPromptTemplateInput,
+        template="""Message: Empty message.""",
+    )
+)
+
+ToolErrorTemplate = PromptTemplate(
+    PromptTemplateInput(
+        schema=ToolErrorTemplateInput,
+        template="""The function has failed; the error log is shown below. If the function cannot accomplish what you want, use a different function or explain why you can't use it.
+{{&reason}}""",  # noqa: E501
+    )
+)
+
 ToolInputErrorTemplate = PromptTemplate(
     PromptTemplateInput(
         schema=ToolInputErrorTemplateInput,
-        template="""{{reason}}
+        template="""{{&reason}}
 HINT: If you're convinced that the input was correct but the function cannot process it then use a different function or say I don't know.""",  # noqa: E501
     )
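
The new templates plug into the same PromptTemplate.render API the runner already uses with plain dicts. A hedged sketch (the reason string is made up):

from beeai_framework.agents.runners.default.prompts import (
    ToolErrorTemplate,
    ToolNoResultsTemplate,
)

print(ToolNoResultsTemplate.render({}))
print(ToolErrorTemplate.render({"reason": "HTTPError: 503 Service Unavailable"}))
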
47 changes: 28 additions & 19 deletions python/beeai_framework/agents/runners/default/runner.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 from collections.abc import Callable
 
 from beeai_framework.agents.runners.base import (
@@ -28,8 +27,11 @@
     SystemPromptTemplate,
     SystemPromptTemplateInput,
     ToolDefinition,
+    ToolErrorTemplate,
     ToolInputErrorTemplate,
+    ToolNoResultsTemplate,
     ToolNotFoundErrorTemplate,
+    UserEmptyPromptTemplate,
     UserPromptTemplate,
 )
 from beeai_framework.agents.types import (
@@ -54,7 +56,7 @@
 from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
 from beeai_framework.tools import ToolError, ToolInputValidationError
 from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput
-from beeai_framework.utils.strings import create_strenum
+from beeai_framework.utils.strings import create_strenum, to_json
 
 
 class DefaultRunner(BaseRunner):
@@ -65,7 +67,10 @@ def default_templates(self) -> BeeAgentTemplates:
             system=SystemPromptTemplate,
             assistant=AssistantPromptTemplate,
             user=UserPromptTemplate,
+            user_empty=UserEmptyPromptTemplate,
             tool_not_found_error=ToolNotFoundErrorTemplate,
+            tool_no_result_error=ToolNoResultsTemplate,
+            tool_error=ToolErrorTemplate,
             tool_input_error=ToolInputErrorTemplate,
             schema_error=SchemaErrorTemplate,
         )
@@ -237,27 +242,27 @@ async def on_error(error: Exception, _: RetryableContext) -> None:
         async def executor(_: RetryableContext) -> BeeRunnerToolResult:
             try:
                 tool_output: ToolOutput = await tool.run(input.state.tool_input, options={})  # TODO: pass tool options
-                return BeeRunnerToolResult(output=tool_output, success=True)
-            # TODO These error templates should be customized to help the LLM to recover
-            except ToolInputValidationError as e:
-                self._failed_attempts_counter.use(e)
+                output = (
+                    tool_output
+                    if not tool_output.is_empty()
+                    else StringToolOutput(self.templates.tool_no_result_error.render({}))
+                )
                 return BeeRunnerToolResult(
-                    success=False,
-                    output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})),
+                    output=output,
+                    success=True,
                 )
-
-            except ToolError as e:
+            except ToolInputValidationError as e:
                 self._failed_attempts_counter.use(e)
-
                 return BeeRunnerToolResult(
                     success=False,
-                    output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})),
+                    output=StringToolOutput(self.templates.tool_input_error.render({"reason": e.explain()})),
                 )
-            except json.JSONDecodeError as e:
-                self._failed_attempts_counter.use(e)
+            except Exception as e:
+                err = ToolError.ensure(e)
+                self._failed_attempts_counter.use(err)
                 return BeeRunnerToolResult(
                     success=False,
-                    output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})),
+                    output=StringToolOutput(self.templates.tool_error.render({"reason": err.explain()})),
                 )
 
         if self._options and self._options.execution and self._options.execution.max_retries_per_step:
@@ -278,10 +283,14 @@ async def init_memory(self, input: BeeRunInput) -> BaseMemory:
             capacity_threshold=0.85, sync_threshold=0.5, llm=self._input.llm
         )  # TODO handlers needs to be fixed
 
-        tool_defs = []
-
-        for tool in self._input.tools:
-            tool_defs.append(ToolDefinition(**tool.prompt_data()))
+        tool_defs = [
+            ToolDefinition(
+                name=tool.name,
+                description=tool.description,
+                input_schema=to_json(tool.input_schema.model_json_schema(mode="validation")),
+            )
+            for tool in self._input.tools
+        ]
 
         system_prompt: str = self.templates.system.render(
             SystemPromptTemplateInput(
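
For orientation, a rough sketch (not the framework code itself) of the executor's new error routing: an empty tool output is replaced by the tool_no_result_error template, validation problems feed tool_input_error, and any other exception is normalized with ToolError.ensure and reported through the new tool_error template.

from beeai_framework.tools import ToolError, ToolInputValidationError

# Illustrative only: same try/except shape as the executor above,
# with print() standing in for template rendering.
try:
    raise RuntimeError("connection reset by peer")  # pretend the tool blew up
except ToolInputValidationError as e:
    print("tool_input_error <-", e.explain())
except Exception as e:
    err = ToolError.ensure(e)  # wrap anything unexpected as a ToolError
    print("tool_error <-", err.explain())
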
12 changes: 11 additions & 1 deletion python/beeai_framework/agents/runners/granite/prompts.py
@@ -18,6 +18,7 @@
     AssistantPromptTemplateInput,
     SchemaErrorTemplateInput,
     SystemPromptTemplateInput,
+    ToolErrorTemplateInput,
     ToolInputErrorTemplateInput,
     ToolNotFoundErrorTemplateInput,
     UserPromptTemplateInput,
@@ -96,12 +97,21 @@
 GraniteToolInputErrorTemplate = PromptTemplate(
     PromptTemplateInput(
         schema=ToolInputErrorTemplateInput,
-        template="""{{reason}}
+        template="""{{&reason}}
 HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.""",  # noqa: E501
     )
 )
 
+GraniteToolErrorTemplate = PromptTemplate(
+    PromptTemplateInput(
+        schema=ToolErrorTemplateInput,
+        template="""The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it.
+{{&reason}}""",  # noqa: E501
+    )
+)
+
 GraniteSchemaErrorTemplate = PromptTemplate(
     PromptTemplateInput(
         schema=SchemaErrorTemplateInput,
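
Both prompt files also switch {{reason}} to {{&reason}}. Assuming the framework's PromptTemplate follows Mustache semantics (illustrated here with the third-party chevron package, an assumption rather than the framework's own engine), the ampersand disables HTML-escaping so raw error text containing quotes or angle brackets reaches the model unmangled:

import chevron  # Mustache implementation, used only to illustrate the semantics

reason = 'Expected type "number", got <str>'
print(chevron.render("{{reason}}", {"reason": reason}))   # &quot;number&quot; ... &lt;str&gt;
print(chevron.render("{{&reason}}", {"reason": reason}))  # "number" ... <str>
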
7 changes: 5 additions & 2 deletions python/beeai_framework/agents/runners/granite/runner.py
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+from beeai_framework.agents.runners.default.prompts import ToolNoResultsTemplate, UserEmptyPromptTemplate
 from beeai_framework.agents.runners.default.runner import DefaultRunner
 from beeai_framework.agents.runners.granite.prompts import (
     GraniteAssistantPromptTemplate,
     GraniteSchemaErrorTemplate,
     GraniteSystemPromptTemplate,
+    GraniteToolErrorTemplate,
     GraniteToolInputErrorTemplate,
     GraniteToolNotFoundErrorTemplate,
     GraniteUserPromptTemplate,
@@ -95,5 +95,8 @@ def default_templates(self) -> BeeAgentTemplates:
             user=GraniteUserPromptTemplate,
             tool_not_found_error=GraniteToolNotFoundErrorTemplate,
             tool_input_error=GraniteToolInputErrorTemplate,
+            tool_error=GraniteToolErrorTemplate,
             schema_error=GraniteSchemaErrorTemplate,
+            user_empty=UserEmptyPromptTemplate,
+            tool_no_result_error=ToolNoResultsTemplate,
         )
6 changes: 3 additions & 3 deletions python/beeai_framework/agents/types.py
@@ -77,10 +77,10 @@ class BeeAgentTemplates(BaseModel):
     system: InstanceOf[PromptTemplate]  # TODO proper template subtypes
     assistant: InstanceOf[PromptTemplate]
     user: InstanceOf[PromptTemplate]
-    # user_empty: InstanceOf[PromptTemplate]
-    # tool_error: InstanceOf[PromptTemplate]
+    user_empty: InstanceOf[PromptTemplate]
+    tool_error: InstanceOf[PromptTemplate]
     tool_input_error: InstanceOf[PromptTemplate]
-    # tool_no_result_error: InstanceOf[PromptTemplate]
+    tool_no_result_error: InstanceOf[PromptTemplate]
     tool_not_found_error: InstanceOf[PromptTemplate]
     schema_error: InstanceOf[PromptTemplate]
 
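
Because the three previously commented-out fields are now required on BeeAgentTemplates, a custom template set has to supply them. A sketch that simply reuses the default templates for the new slots:

from beeai_framework.agents.runners.default.prompts import (
    AssistantPromptTemplate,
    SchemaErrorTemplate,
    SystemPromptTemplate,
    ToolErrorTemplate,
    ToolInputErrorTemplate,
    ToolNoResultsTemplate,
    ToolNotFoundErrorTemplate,
    UserEmptyPromptTemplate,
    UserPromptTemplate,
)
from beeai_framework.agents.types import BeeAgentTemplates

templates = BeeAgentTemplates(
    system=SystemPromptTemplate,
    assistant=AssistantPromptTemplate,
    user=UserPromptTemplate,
    user_empty=UserEmptyPromptTemplate,
    tool_error=ToolErrorTemplate,
    tool_input_error=ToolInputErrorTemplate,
    tool_no_result_error=ToolNoResultsTemplate,
    tool_not_found_error=ToolNotFoundErrorTemplate,
    schema_error=SchemaErrorTemplate,
)
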
2 changes: 1 addition & 1 deletion python/beeai_framework/backend/chat.py
@@ -254,7 +254,7 @@ async def run_create(context: RunContext) -> ChatModelOutput:
             except Exception as ex:
                 error = ChatModelError.ensure(ex)
                 await context.emitter.emit("error", {"error": error})
-                raise error from None
+                raise error
             finally:
                 await context.emitter.emit("finish", None)
 
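
Dropping `from None` matters for debugging and retries: `raise error from None` suppresses the implicit exception context, so the underlying failure no longer shows up in tracebacks, while a plain `raise error` keeps the chain. A standalone illustration of the Python semantics:

class WrappedError(Exception):
    pass


try:
    try:
        1 / 0
    except ZeroDivisionError:
        raise WrappedError("model call failed") from None
except WrappedError as err:
    print(err.__context__)           # the ZeroDivisionError is still stored...
    print(err.__suppress_context__)  # ...but True hides it from the traceback
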
10 changes: 6 additions & 4 deletions python/beeai_framework/tools/errors.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from pydantic import ValidationError
 
 from beeai_framework.errors import FrameworkError
 
@@ -21,6 +21,8 @@ def __init__(self, message: str = "Tool Error", *, cause: Exception | None = None) -> None:
         super().__init__(message, is_fatal=True, is_retryable=False, cause=cause)
 
 
-class ToolInputValidationError(FrameworkError):
-    def __init__(self, message: str = "Tool Input Validation Error", *, cause: Exception | None = None) -> None:
-        super().__init__(message, is_fatal=True, is_retryable=False, cause=cause)
+class ToolInputValidationError(ToolError):
+    def __init__(
+        self, message: str = "Tool Input Validation Error", *, cause: ValidationError | ValueError | None = None
+    ) -> None:
+        super().__init__(message, cause=cause)
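
A sketch of what the new hierarchy buys: ToolInputValidationError now subclasses ToolError instead of FrameworkError directly, so a single `except ToolError` handler also catches input-validation failures.

from beeai_framework.tools import ToolError, ToolInputValidationError

try:
    raise ToolInputValidationError("Tool input validation error")
except ToolError as err:
    print(type(err).__name__, "is a ToolError:", isinstance(err, ToolError))
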
18 changes: 5 additions & 13 deletions python/beeai_framework/tools/tool.py
@@ -22,7 +22,6 @@
 
 from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance
 from beeai_framework.emitter.emitter import Emitter
-from beeai_framework.errors import FrameworkError
 from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
 from beeai_framework.tools.errors import ToolError, ToolInputValidationError
 from beeai_framework.utils import BeeLogger
@@ -91,17 +90,10 @@ def validate_input(self, input: T | dict[str, Any]) -> T:
         try:
             return self.input_schema.model_validate(input)
         except ValidationError as e:
-            raise ToolInputValidationError("Tool input validation error") from e
+            raise ToolInputValidationError("Tool input validation error", cause=e)
 
-    def prompt_data(self) -> dict[str, str]:
-        return {
-            "name": self.name,
-            "description": self.description,
-            "input_schema": str(self.input_schema.model_json_schema(mode="serialization")),
-        }
-
-    def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Run[Any]:
-        async def run_tool(context: RunContext) -> Any:
+    def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Run[T]:
+        async def run_tool(context: RunContext) -> T:
             error_propagated = False
 
             try:
@@ -118,7 +110,7 @@ async def executor(_: RetryableContext) -> Any:
             async def on_error(error: Exception, _: RetryableContext) -> None:
                 nonlocal error_propagated
                 error_propagated = True
-                err = FrameworkError.ensure(error)
+                err = ToolError.ensure(error)
                 await context.emitter.emit("error", {"error": err, **meta})
                 if err.is_fatal:
                     raise err from None
@@ -144,7 +136,7 @@ async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:
                 err = ToolError.ensure(e)
                 if not error_propagated:
                     await context.emitter.emit("error", {"error": err, "input": input, "options": options})
-                raise err from None
+                raise err
             finally:
                 await context.emitter.emit("finish", None)
 
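
With Tool.prompt_data() removed, callers read name/description/input_schema off the tool directly, as the runner now does when building ToolDefinition entries. A hedged sketch of that replacement pattern:

from beeai_framework.agents.runners.default.prompts import ToolDefinition
from beeai_framework.tools.weather.openmeteo import OpenMeteoTool
from beeai_framework.utils.strings import to_json

tool = OpenMeteoTool()
tool_def = ToolDefinition(
    name=tool.name,
    description=tool.description,
    input_schema=to_json(tool.input_schema.model_json_schema(mode="validation")),
)
print(tool_def.name)
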
4 changes: 2 additions & 2 deletions python/beeai_framework/tools/weather/openmeteo.py
@@ -106,8 +106,8 @@ def _trim_date(date_str: str) -> str:
                 start = datetime.strptime(_trim_date(start_date), "%Y-%m-%d").replace(tzinfo=UTC)
             except ValueError as e:
                 raise ToolInputValidationError(
-                    "'start_date' is incorrectly formatted, please use the correct format YYYY-MM-DD."
-                ) from e
+                    "'start_date' is incorrectly formatted, please use the correct format YYYY-MM-DD.", cause=e
+                )
         else:
             start = datetime.now(UTC)
 
7 changes: 1 addition & 6 deletions python/docs/agents.md
@@ -112,18 +112,13 @@ Customize how the agent formats prompts, including the system prompt that defines
 from beeai_framework.agents.runners.default.prompts import (
     SystemPromptTemplate,
     SystemPromptTemplateInput,
-    ToolDefinition,
 )
 from beeai_framework.tools.weather.openmeteo import OpenMeteoTool
 
 tool = OpenMeteoTool()
 
 # Render the granite system prompt
-prompt = SystemPromptTemplate.render(
-    SystemPromptTemplateInput(
-        instructions="You are a helpful AI assistant!", tools=[ToolDefinition(**tool.prompt_data())], tools_length=1
-    )
-)
+prompt = SystemPromptTemplate.render(SystemPromptTemplateInput(instructions="You are a helpful AI assistant!"))
 
 print(prompt)
 

