diff --git a/python/beeai_framework/adapters/litellm/chat.py b/python/beeai_framework/adapters/litellm/chat.py index 987f09690..2da6e32be 100644 --- a/python/beeai_framework/adapters/litellm/chat.py +++ b/python/beeai_framework/adapters/litellm/chat.py @@ -159,7 +159,21 @@ def _transform_input(self, input: ChatModelInput) -> LiteLLMParameters: else: messages.append(message.to_plain()) - tools = [{"type": "function", "function": tool.prompt_data()} for tool in input.tools] if input.tools else None + tools = ( + [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema.model_json_schema(mode="validation"), + }, + } + for tool in input.tools + ] + if input.tools + else None + ) return LiteLLMParameters( model=f"{self._litellm_provider_id}/{self.model_id}", diff --git a/python/beeai_framework/agents/runners/base.py b/python/beeai_framework/agents/runners/base.py index e0ae7f6da..3ce42f210 100644 --- a/python/beeai_framework/agents/runners/base.py +++ b/python/beeai_framework/agents/runners/base.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import math from abc import ABC, abstractmethod from dataclasses import dataclass @@ -73,7 +72,9 @@ def __init__(self, input: BeeInput, options: BeeRunOptions, run: RunContext) -> self._failed_attempts_counter: RetryCounter = RetryCounter( error_type=AgentError, max_retries=( - options.execution.total_max_retries if options.execution and options.execution.total_max_retries else 0 + options.execution.total_max_retries + if options.execution and options.execution.total_max_retries + else math.inf ), ) self._run = run @@ -93,7 +94,7 @@ async def create_iteration(self) -> RunnerIteration: max_iterations = ( self._options.execution.max_iterations if self._options.execution and self._options.execution.max_iterations - else 0 + else math.inf ) if meta.iteration > max_iterations: diff --git a/python/beeai_framework/agents/runners/default/prompts.py b/python/beeai_framework/agents/runners/default/prompts.py index 5a05f7645..348426f70 100644 --- a/python/beeai_framework/agents/runners/default/prompts.py +++ b/python/beeai_framework/agents/runners/default/prompts.py @@ -17,6 +17,14 @@ from beeai_framework.template import PromptTemplate, PromptTemplateInput +class UserEmptyPromptTemplateInput(BaseModel): + pass + + +class ToolNoResultsTemplateInput(BaseModel): + pass + + class UserPromptTemplateInput(BaseModel): input: str @@ -49,6 +57,10 @@ class ToolInputErrorTemplateInput(BaseModel): reason: str +class ToolErrorTemplateInput(BaseModel): + reason: str + + class SchemaErrorTemplateInput(BaseModel): pass @@ -149,10 +161,33 @@ class SchemaErrorTemplateInput(BaseModel): ) ) +ToolNoResultsTemplate = PromptTemplate( + PromptTemplateInput( + schema=ToolNoResultsTemplateInput, + template="""No results were found!""", + ) +) + +UserEmptyPromptTemplate = PromptTemplate( + PromptTemplateInput( + schema=UserEmptyPromptTemplateInput, + template="""Message: Empty message.""", + ) +) + +ToolErrorTemplate = PromptTemplate( + PromptTemplateInput( + schema=ToolErrorTemplateInput, + template="""The function has failed; the error log is shown below. If the function cannot accomplish what you want, use a different function or explain why you can't use it. + +{{&reason}}""", # noqa: E501 + ) +) + ToolInputErrorTemplate = PromptTemplate( PromptTemplateInput( schema=ToolInputErrorTemplateInput, - template="""{{reason}} + template="""{{&reason}} HINT: If you're convinced that the input was correct but the function cannot process it then use a different function or say I don't know.""", # noqa: E501 ) diff --git a/python/beeai_framework/agents/runners/default/runner.py b/python/beeai_framework/agents/runners/default/runner.py index 7c214a40a..e3184eee3 100644 --- a/python/beeai_framework/agents/runners/default/runner.py +++ b/python/beeai_framework/agents/runners/default/runner.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from collections.abc import Callable from beeai_framework.agents.runners.base import ( @@ -28,8 +27,11 @@ SystemPromptTemplate, SystemPromptTemplateInput, ToolDefinition, + ToolErrorTemplate, ToolInputErrorTemplate, + ToolNoResultsTemplate, ToolNotFoundErrorTemplate, + UserEmptyPromptTemplate, UserPromptTemplate, ) from beeai_framework.agents.types import ( @@ -54,7 +56,7 @@ from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput from beeai_framework.tools import ToolError, ToolInputValidationError from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput -from beeai_framework.utils.strings import create_strenum +from beeai_framework.utils.strings import create_strenum, to_json class DefaultRunner(BaseRunner): @@ -65,7 +67,10 @@ def default_templates(self) -> BeeAgentTemplates: system=SystemPromptTemplate, assistant=AssistantPromptTemplate, user=UserPromptTemplate, + user_empty=UserEmptyPromptTemplate, tool_not_found_error=ToolNotFoundErrorTemplate, + tool_no_result_error=ToolNoResultsTemplate, + tool_error=ToolErrorTemplate, tool_input_error=ToolInputErrorTemplate, schema_error=SchemaErrorTemplate, ) @@ -237,27 +242,27 @@ async def on_error(error: Exception, _: RetryableContext) -> None: async def executor(_: RetryableContext) -> BeeRunnerToolResult: try: tool_output: ToolOutput = await tool.run(input.state.tool_input, options={}) # TODO: pass tool options - return BeeRunnerToolResult(output=tool_output, success=True) - # TODO These error templates should be customized to help the LLM to recover - except ToolInputValidationError as e: - self._failed_attempts_counter.use(e) + output = ( + tool_output + if not tool_output.is_empty() + else StringToolOutput(self.templates.tool_no_result_error.render({})) + ) return BeeRunnerToolResult( - success=False, - output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + output=output, + success=True, ) - - except ToolError as e: + except ToolInputValidationError as e: self._failed_attempts_counter.use(e) - return BeeRunnerToolResult( success=False, - output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + output=StringToolOutput(self.templates.tool_input_error.render({"reason": e.explain()})), ) - except json.JSONDecodeError as e: - self._failed_attempts_counter.use(e) + except Exception as e: + err = ToolError.ensure(e) + self._failed_attempts_counter.use(err) return BeeRunnerToolResult( success=False, - output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + output=StringToolOutput(self.templates.tool_error.render({"reason": err.explain()})), ) if self._options and self._options.execution and self._options.execution.max_retries_per_step: @@ -278,10 +283,14 @@ async def init_memory(self, input: BeeRunInput) -> BaseMemory: capacity_threshold=0.85, sync_threshold=0.5, llm=self._input.llm ) # TODO handlers needs to be fixed - tool_defs = [] - - for tool in self._input.tools: - tool_defs.append(ToolDefinition(**tool.prompt_data())) + tool_defs = [ + ToolDefinition( + name=tool.name, + description=tool.description, + input_schema=to_json(tool.input_schema.model_json_schema(mode="validation")), + ) + for tool in self._input.tools + ] system_prompt: str = self.templates.system.render( SystemPromptTemplateInput( diff --git a/python/beeai_framework/agents/runners/granite/prompts.py b/python/beeai_framework/agents/runners/granite/prompts.py index 7f7850504..3c96477f3 100644 --- a/python/beeai_framework/agents/runners/granite/prompts.py +++ b/python/beeai_framework/agents/runners/granite/prompts.py @@ -18,6 +18,7 @@ AssistantPromptTemplateInput, SchemaErrorTemplateInput, SystemPromptTemplateInput, + ToolErrorTemplateInput, ToolInputErrorTemplateInput, ToolNotFoundErrorTemplateInput, UserPromptTemplateInput, @@ -96,12 +97,21 @@ GraniteToolInputErrorTemplate = PromptTemplate( PromptTemplateInput( schema=ToolInputErrorTemplateInput, - template="""{{reason}} + template="""{{&reason}} HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.""", # noqa: E501 ) ) +GraniteToolErrorTemplate = PromptTemplate( + PromptTemplateInput( + schema=ToolErrorTemplateInput, + template="""The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. + +{{&reason}}""", # noqa: E501 + ) +) + GraniteSchemaErrorTemplate = PromptTemplate( PromptTemplateInput( schema=SchemaErrorTemplateInput, diff --git a/python/beeai_framework/agents/runners/granite/runner.py b/python/beeai_framework/agents/runners/granite/runner.py index 5d0aeead0..eaa872e00 100644 --- a/python/beeai_framework/agents/runners/granite/runner.py +++ b/python/beeai_framework/agents/runners/granite/runner.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +from beeai_framework.agents.runners.default.prompts import ToolNoResultsTemplate, UserEmptyPromptTemplate from beeai_framework.agents.runners.default.runner import DefaultRunner from beeai_framework.agents.runners.granite.prompts import ( GraniteAssistantPromptTemplate, GraniteSchemaErrorTemplate, GraniteSystemPromptTemplate, + GraniteToolErrorTemplate, GraniteToolInputErrorTemplate, GraniteToolNotFoundErrorTemplate, GraniteUserPromptTemplate, @@ -95,5 +95,8 @@ def default_templates(self) -> BeeAgentTemplates: user=GraniteUserPromptTemplate, tool_not_found_error=GraniteToolNotFoundErrorTemplate, tool_input_error=GraniteToolInputErrorTemplate, + tool_error=GraniteToolErrorTemplate, schema_error=GraniteSchemaErrorTemplate, + user_empty=UserEmptyPromptTemplate, + tool_no_result_error=ToolNoResultsTemplate, ) diff --git a/python/beeai_framework/agents/types.py b/python/beeai_framework/agents/types.py index 5e221aa8c..edd88a9bc 100644 --- a/python/beeai_framework/agents/types.py +++ b/python/beeai_framework/agents/types.py @@ -77,10 +77,10 @@ class BeeAgentTemplates(BaseModel): system: InstanceOf[PromptTemplate] # TODO proper template subtypes assistant: InstanceOf[PromptTemplate] user: InstanceOf[PromptTemplate] - # user_empty: InstanceOf[PromptTemplate] - # tool_error: InstanceOf[PromptTemplate] + user_empty: InstanceOf[PromptTemplate] + tool_error: InstanceOf[PromptTemplate] tool_input_error: InstanceOf[PromptTemplate] - # tool_no_result_error: InstanceOf[PromptTemplate] + tool_no_result_error: InstanceOf[PromptTemplate] tool_not_found_error: InstanceOf[PromptTemplate] schema_error: InstanceOf[PromptTemplate] diff --git a/python/beeai_framework/backend/chat.py b/python/beeai_framework/backend/chat.py index e966aff2b..65883e481 100644 --- a/python/beeai_framework/backend/chat.py +++ b/python/beeai_framework/backend/chat.py @@ -254,7 +254,7 @@ async def run_create(context: RunContext) -> ChatModelOutput: except Exception as ex: error = ChatModelError.ensure(ex) await context.emitter.emit("error", {"error": error}) - raise error from None + raise error finally: await context.emitter.emit("finish", None) diff --git a/python/beeai_framework/tools/errors.py b/python/beeai_framework/tools/errors.py index 173f71e50..b96659c74 100644 --- a/python/beeai_framework/tools/errors.py +++ b/python/beeai_framework/tools/errors.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from pydantic import ValidationError from beeai_framework.errors import FrameworkError @@ -21,6 +21,8 @@ def __init__(self, message: str = "Tool Error", *, cause: Exception | None = Non super().__init__(message, is_fatal=True, is_retryable=False, cause=cause) -class ToolInputValidationError(FrameworkError): - def __init__(self, message: str = "Tool Input Validation Error", *, cause: Exception | None = None) -> None: - super().__init__(message, is_fatal=True, is_retryable=False, cause=cause) +class ToolInputValidationError(ToolError): + def __init__( + self, message: str = "Tool Input Validation Error", *, cause: ValidationError | ValueError | None = None + ) -> None: + super().__init__(message, cause=cause) diff --git a/python/beeai_framework/tools/tool.py b/python/beeai_framework/tools/tool.py index 0102bcc57..931bf5295 100644 --- a/python/beeai_framework/tools/tool.py +++ b/python/beeai_framework/tools/tool.py @@ -22,7 +22,6 @@ from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance from beeai_framework.emitter.emitter import Emitter -from beeai_framework.errors import FrameworkError from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput from beeai_framework.tools.errors import ToolError, ToolInputValidationError from beeai_framework.utils import BeeLogger @@ -91,17 +90,10 @@ def validate_input(self, input: T | dict[str, Any]) -> T: try: return self.input_schema.model_validate(input) except ValidationError as e: - raise ToolInputValidationError("Tool input validation error") from e + raise ToolInputValidationError("Tool input validation error", cause=e) - def prompt_data(self) -> dict[str, str]: - return { - "name": self.name, - "description": self.description, - "input_schema": str(self.input_schema.model_json_schema(mode="serialization")), - } - - def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Run[Any]: - async def run_tool(context: RunContext) -> Any: + def run(self, input: T | dict[str, Any], options: dict[str, Any] | None = None) -> Run[T]: + async def run_tool(context: RunContext) -> T: error_propagated = False try: @@ -118,7 +110,7 @@ async def executor(_: RetryableContext) -> Any: async def on_error(error: Exception, _: RetryableContext) -> None: nonlocal error_propagated error_propagated = True - err = FrameworkError.ensure(error) + err = ToolError.ensure(error) await context.emitter.emit("error", {"error": err, **meta}) if err.is_fatal: raise err from None @@ -144,7 +136,7 @@ async def on_retry(ctx: RetryableContext, last_error: Exception) -> None: err = ToolError.ensure(e) if not error_propagated: await context.emitter.emit("error", {"error": err, "input": input, "options": options}) - raise err from None + raise err finally: await context.emitter.emit("finish", None) diff --git a/python/beeai_framework/tools/weather/openmeteo.py b/python/beeai_framework/tools/weather/openmeteo.py index 075375fba..162b03665 100644 --- a/python/beeai_framework/tools/weather/openmeteo.py +++ b/python/beeai_framework/tools/weather/openmeteo.py @@ -106,8 +106,8 @@ def _trim_date(date_str: str) -> str: start = datetime.strptime(_trim_date(start_date), "%Y-%m-%d").replace(tzinfo=UTC) except ValueError as e: raise ToolInputValidationError( - "'start_date' is incorrectly formatted, please use the correct format YYYY-MM-DD." - ) from e + "'start_date' is incorrectly formatted, please use the correct format YYYY-MM-DD.", cause=e + ) else: start = datetime.now(UTC) diff --git a/python/docs/agents.md b/python/docs/agents.md index 6098a1d49..267b202cc 100644 --- a/python/docs/agents.md +++ b/python/docs/agents.md @@ -112,18 +112,13 @@ Customize how the agent formats prompts, including the system prompt that define from beeai_framework.agents.runners.default.prompts import ( SystemPromptTemplate, SystemPromptTemplateInput, - ToolDefinition, ) from beeai_framework.tools.weather.openmeteo import OpenMeteoTool tool = OpenMeteoTool() # Render the granite system prompt -prompt = SystemPromptTemplate.render( - SystemPromptTemplateInput( - instructions="You are a helpful AI assistant!", tools=[ToolDefinition(**tool.prompt_data())], tools_length=1 - ) -) +prompt = SystemPromptTemplate.render(SystemPromptTemplateInput(instructions="You are a helpful AI assistant!")) print(prompt) diff --git a/python/docs/templates.md b/python/docs/templates.md index 97c995365..80c84c31b 100644 --- a/python/docs/templates.md +++ b/python/docs/templates.md @@ -172,18 +172,13 @@ The framework's agents use specialized templates to structure their behavior. Yo from beeai_framework.agents.runners.default.prompts import ( SystemPromptTemplate, SystemPromptTemplateInput, - ToolDefinition, ) from beeai_framework.tools.weather.openmeteo import OpenMeteoTool tool = OpenMeteoTool() # Render the granite system prompt -prompt = SystemPromptTemplate.render( - SystemPromptTemplateInput( - instructions="You are a helpful AI assistant!", tools=[ToolDefinition(**tool.prompt_data())], tools_length=1 - ) -) +prompt = SystemPromptTemplate.render(SystemPromptTemplateInput(instructions="You are a helpful AI assistant!")) print(prompt) diff --git a/python/examples/README.md b/python/examples/README.md index 3872fbfcb..5d8236b6e 100644 --- a/python/examples/README.md +++ b/python/examples/README.md @@ -50,7 +50,6 @@ This repository contains examples demonstrating the usage of the BeeAI Framework - [`basic_functions.py`](/python/examples/templates/basic_functions.py): Basic functions - [`basic_template.py`](/python/examples/templates/basic_template.py): Basic template -- [`agent_sys_prompt.py`](/python/examples/templates/agent_sys_prompt.py): System Prompt ## Tools diff --git a/python/examples/templates/agent_sys_prompt.py b/python/examples/templates/agent_sys_prompt.py index b281add17..296b8c000 100644 --- a/python/examples/templates/agent_sys_prompt.py +++ b/python/examples/templates/agent_sys_prompt.py @@ -1,17 +1,12 @@ from beeai_framework.agents.runners.default.prompts import ( SystemPromptTemplate, SystemPromptTemplateInput, - ToolDefinition, ) from beeai_framework.tools.weather.openmeteo import OpenMeteoTool tool = OpenMeteoTool() # Render the granite system prompt -prompt = SystemPromptTemplate.render( - SystemPromptTemplateInput( - instructions="You are a helpful AI assistant!", tools=[ToolDefinition(**tool.prompt_data())], tools_length=1 - ) -) +prompt = SystemPromptTemplate.render(SystemPromptTemplateInput(instructions="You are a helpful AI assistant!")) print(prompt)