Skip to content

Commit

Permalink
refactor(agents): Add generics to BaseAgent and other mypy fixes (#517)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Bozarth <ajbozart@us.ibm.com>
  • Loading branch information
ajbozarth authored Mar 7, 2025
1 parent ee58c5f commit e70e56e
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 48 deletions.
16 changes: 9 additions & 7 deletions python/beeai_framework/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from beeai_framework.memory import BaseMemory
from beeai_framework.utils.models import ModelLike

T = TypeVar("T", bound=BaseModel)
RI = TypeVar("RI", bound=BaseModel)
RO = TypeVar("RO", bound=BaseModel)
TInput = TypeVar("TInput", bound=BaseModel)
TOptions = TypeVar("TOptions", bound=BaseModel)
TOutput = TypeVar("TOutput", bound=BaseModel)


class BaseAgent(ABC, Generic[T]):
class BaseAgent(ABC, Generic[TInput, TOptions, TOutput]):
is_running: bool = False
emitter: Emitter

Expand All @@ -39,13 +39,13 @@ def run(
prompt: str | None = None,
execution: AgentExecutionConfig | None = None,
signal: AbortSignal | None = None,
) -> Run[T]:
) -> Run[TOutput]:
if self.is_running:
raise RuntimeError("Agent is already running!")

self.is_running = True

async def handler(context: RunContext) -> T:
async def handler(context: RunContext) -> TOutput:
try:
return await self._run({"prompt": prompt}, {"execution": execution, "signal": signal}, context)
finally:
Expand All @@ -61,7 +61,9 @@ async def handler(context: RunContext) -> T:
)

@abstractmethod
async def _run(self, run_input: ModelLike[RI], options: ModelLike[RO] | None, context: RunContext) -> T:
async def _run(
self, run_input: ModelLike[TInput], options: ModelLike[TOptions] | None, context: RunContext
) -> TOutput:
pass

def destroy(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions python/beeai_framework/agents/react/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
ReActAgentRunOptions,
ReActAgentRunOutput,
ReActAgentTemplateFactory,
ReActAgentTemplates,
)
from beeai_framework.agents.types import (
AgentExecutionConfig,
Expand All @@ -44,11 +43,12 @@
from beeai_framework.context import RunContext
from beeai_framework.emitter import Emitter
from beeai_framework.memory import BaseMemory
from beeai_framework.template import PromptTemplate
from beeai_framework.tools.tool import Tool
from beeai_framework.utils.models import ModelLike, to_model, to_model_optional


class ReActAgent(BaseAgent[ReActAgentRunOutput]):
class ReActAgent(BaseAgent[ReActAgentRunInput, ReActAgentRunOptions, ReActAgentRunOutput]):
runner: Callable[..., BaseRunner]

def __init__(
Expand All @@ -57,7 +57,7 @@ def __init__(
tools: list[Tool],
memory: BaseMemory,
meta: AgentMeta | None = None,
templates: dict[ModelKeysType, ReActAgentTemplates | ReActAgentTemplateFactory] | None = None,
templates: dict[ModelKeysType, PromptTemplate | ReActAgentTemplateFactory] | None = None,
execution: AgentExecutionConfig | None = None,
stream: bool | None = None,
) -> None:
Expand Down
33 changes: 15 additions & 18 deletions python/beeai_framework/agents/react/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass

from pydantic import BaseModel, InstanceOf

from beeai_framework.agents import AgentError
from beeai_framework.agents.react.types import (
Expand All @@ -36,33 +37,29 @@
from beeai_framework.utils.counter import RetryCounter


@dataclass
class ReActAgentRunnerLLMInput:
class ReActAgentRunnerLLMInput(BaseModel):
meta: ReActAgentIterationMeta
signal: AbortSignal
emitter: Emitter
signal: InstanceOf[AbortSignal]
emitter: InstanceOf[Emitter]


@dataclass
class ReActAgentRunnerIteration:
emitter: Emitter
state: ReActAgentIterationResult
class ReActAgentRunnerIteration(BaseModel):
emitter: InstanceOf[Emitter]
state: InstanceOf[ReActAgentIterationResult]
meta: ReActAgentIterationMeta
signal: AbortSignal
signal: InstanceOf[AbortSignal]


@dataclass
class ReActAgentRunnerToolResult:
output: ToolOutput
class ReActAgentRunnerToolResult(BaseModel):
output: InstanceOf[ToolOutput]
success: bool


@dataclass
class ReActAgentRunnerToolInput:
state: ReActAgentIterationResult
class ReActAgentRunnerToolInput(BaseModel):
state: InstanceOf[ReActAgentIterationResult]
meta: ReActAgentIterationMeta
signal: AbortSignal
emitter: Emitter
signal: InstanceOf[AbortSignal]
emitter: InstanceOf[Emitter]


class BaseRunner(ABC):
Expand Down
6 changes: 4 additions & 2 deletions python/beeai_framework/agents/react/runners/default/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
)
from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
from beeai_framework.tools import ToolError, ToolInputValidationError
from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput
from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput, ToolRunOptions
from beeai_framework.utils.strings import create_strenum, to_json


Expand Down Expand Up @@ -267,7 +267,9 @@ async def on_error(error: Exception, _: RetryableContext) -> None:

async def executor(_: RetryableContext) -> ReActAgentRunnerToolResult:
try:
tool_output: ToolOutput = await tool.run(input.state.tool_input, options={}) # TODO: pass tool options
tool_output: ToolOutput = await tool.run(
input.state.tool_input, options=ToolRunOptions()
) # TODO: pass tool options
output = (
tool_output
if not tool_output.is_empty()
Expand Down
7 changes: 4 additions & 3 deletions python/beeai_framework/agents/react/runners/granite/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,18 @@ async def on_update(data: dict, event: EventMeta) -> None:
update = data.get("update")
assert update is not None
if update.get("key") == "tool_output":
memory: BaseMemory = data.get("memory")
memory = data.get("memory")
assert isinstance(memory, BaseMemory)
tool_output: ToolOutput = update.get("value")
tool_result = MessageToolResultContent(
result=tool_output.get_text_content(),
tool_name=data.get("data").tool_name,
tool_name=data["data"].tool_name,
tool_call_id="DUMMY_ID",
)
await memory.add(
ToolMessage(
content=tool_result,
meta={"success": data.get("meta").get("success", True)},
meta={"success": data["meta"]["success"] or True},
)
)

Expand Down
2 changes: 1 addition & 1 deletion python/beeai_framework/agents/react/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def to_template(self) -> dict:

class ReActAgentRunIteration(BaseModel):
raw: InstanceOf[ChatModelOutput]
state: ReActAgentIterationResult
state: InstanceOf[ReActAgentIterationResult]


class ReActAgentRunOutput(BaseModel):
Expand Down
8 changes: 2 additions & 6 deletions python/docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,7 @@ class RunOutput(BaseModel):
state: State


class RunOptions(ReActAgentRunOptions):
max_retries: int | None = None


class CustomAgent(BaseAgent[RunOutput]):
class CustomAgent(BaseAgent[ReActAgentRunInput, ReActAgentRunOptions, RunOutput]):
memory: BaseMemory | None = None

def __init__(self, llm: ChatModel, memory: BaseMemory) -> None:
Expand Down Expand Up @@ -301,7 +297,7 @@ class CustomAgent(BaseAgent[RunOutput]):
messages=[
SystemMessage("You are a helpful assistant. Always use JSON format for your responses."),
*(self.memory.messages if self.memory is not None else []),
UserMessage(run_input.prompt),
UserMessage(run_input.prompt or ""),
],
max_retries=options.execution.total_max_retries if options and options.execution else None,
abort_signal=context.signal,
Expand Down
8 changes: 2 additions & 6 deletions python/examples/agents/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@ class RunOutput(BaseModel):
state: State


class RunOptions(ReActAgentRunOptions):
max_retries: int | None = None


class CustomAgent(BaseAgent[RunOutput]):
class CustomAgent(BaseAgent[ReActAgentRunInput, ReActAgentRunOptions, RunOutput]):
memory: BaseMemory | None = None

def __init__(self, llm: ChatModel, memory: BaseMemory) -> None:
Expand Down Expand Up @@ -67,7 +63,7 @@ class CustomSchema(BaseModel):
messages=[
SystemMessage("You are a helpful assistant. Always use JSON format for your responses."),
*(self.memory.messages if self.memory is not None else []),
UserMessage(run_input.prompt),
UserMessage(run_input.prompt or ""),
],
max_retries=options.execution.total_max_retries if options and options.execution else None,
abort_signal=context.signal,
Expand Down
6 changes: 4 additions & 2 deletions python/tests/runners/test_default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
AgentExecutionConfig,
)
from beeai_framework.backend.chat import ChatModel
from beeai_framework.cancellation import AbortSignal
from beeai_framework.emitter import Emitter
from beeai_framework.memory.token_memory import TokenMemory
from beeai_framework.tools.weather.openmeteo import OpenMeteoTool

Expand Down Expand Up @@ -56,8 +58,8 @@ async def test_runner_init() -> None:
await runner.tool(
input=ReActAgentRunnerToolInput(
state=ReActAgentIterationResult(tool_name="OpenMeteoTool", tool_input={"location_name": "White Plains"}),
emitter=None,
emitter=Emitter(),
meta=ReActAgentIterationMeta(iteration=0),
signal=None,
signal=AbortSignal(),
)
)

0 comments on commit e70e56e

Please sign in to comment.