feat: Add Run.on API to directly add an event callback (#516)
This introduces the `Run.on` method, which combines the `Run.observe` and `Emitter.match` calls into a single, more user-friendly way to register an event callback.

Signed-off-by: Alex Bozarth <ajbozart@us.ibm.com>
ajbozarth authored Mar 7, 2025
1 parent e70e56e commit 65abb91
Showing 4 changed files with 28 additions and 35 deletions.
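To see what the new API means for callers, here is a minimal before/after sketch. It assumes an `agent`, a `prompt`, and an `update_callback(data, event)` set up as in the example files changed below; treat it as an illustration of the pattern, not as an excerpt from the repository.

# Before: wrap the registration in an observer function and pass it to Run.observe.
def on_update(emitter: Emitter) -> None:
    emitter.on("update", update_callback)

output = await agent.run(prompt).observe(on_update)

# After: register the matcher and callback directly on the Run.
output = await agent.run(prompt).on("update", update_callback)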
python/beeai_framework/context.py (32 changes: 19 additions & 13 deletions)
@@ -25,7 +25,7 @@
 from pydantic import BaseModel
 
 from beeai_framework.cancellation import AbortController, AbortSignal, register_signals
-from beeai_framework.emitter import Emitter, EventTrace
+from beeai_framework.emitter import Callback, Emitter, EmitterOptions, EventTrace, Matcher
 from beeai_framework.errors import AbortError, FrameworkError
 from beeai_framework.logger import Logger
 from beeai_framework.utils.asynchronous import ensure_async
@@ -34,6 +34,8 @@
 
 logger = Logger(__name__)
 
+storage: ContextVar["RunContext"] = ContextVar("storage")
+
 
 @dataclass
 class RunInstance:
@@ -49,29 +51,35 @@ class Run(Generic[R]):
     def __init__(self, handler: Callable[[], R | Awaitable[R]], context: "RunContext") -> None:
         super().__init__()
         self.handler = ensure_async(handler)
-        self.tasks: list[tuple[Callable[[Any], None], Any]] = []
+        self.tasks: list[tuple[Callable, list]] = []
         self.run_context = context
 
     def __await__(self) -> Generator[Any, None, R]:
         return self._run_tasks().__await__()
 
-    def observe(self, fn: Callable[[Emitter], Any]) -> Self:
-        self.tasks.append((fn, self.run_context.emitter))
+    def observe(self, fn: Callable[[Emitter], None]) -> Self:
+        self.tasks.append((fn, [self.run_context.emitter]))
         return self
 
+    def on(self, matcher: Matcher, callback: Callback, options: EmitterOptions | None = None) -> Self:
+        self.tasks.append((self.run_context.emitter.match, [matcher, callback, options]))
+        return self
+
     def context(self, context: dict) -> Self:
-        self.tasks.append((self._set_context, context))
+        self.tasks.append((self._set_context, [context]))
         return self
 
     def middleware(self, fn: Callable[["RunContext"], None]) -> Self:
-        self.tasks.append((fn, self.run_context))
+        self.tasks.append((fn, [self.run_context]))
         return self
 
     async def _run_tasks(self) -> R:
-        for fn, param in self.tasks:
-            await ensure_async(fn)(param)
+        tasks = self.tasks[:]
+        self.tasks.clear()
+
+        for fn, params in tasks:
+            await ensure_async(fn)(*params)
 
         return await self.handler()
 
     def _set_context(self, context: dict) -> None:
@@ -80,8 +88,6 @@ def _set_context(self, context: dict) -> None:
 
 
 class RunContext(RunInstance):
-    storage: ContextVar[Self] = ContextVar("storage", default=None)
-
     def __init__(self, *, instance: RunInstance, context_input: RunContextInput, parent: Self | None = None) -> None:
         self.instance = instance
         self.context_input = context_input
@@ -124,7 +130,7 @@ def destroy(self) -> None:
     def enter(
         instance: RunInstance, context_input: RunContextInput, fn: Callable[["RunContext"], Awaitable[R]]
     ) -> Run[R]:
-        parent = RunContext.storage.get()
+        parent = storage.get(None)
         context = RunContext(instance=instance, context_input=context_input, parent=parent)
 
         async def handler() -> R:
@@ -134,7 +140,7 @@ async def handler() -> R:
             await emitter.emit("start", None)
 
             async def _context_storage_run() -> R:
-                RunContext.storage.set(context)
+                storage.set(context)
                 return await fn(context)
 
             async def _context_signal_aborted() -> None:
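In short, `Run.on` queues the emitter's own `match` call and defers it until the run is awaited, so for any `Run` instance the new form should behave like the observe-based registration it replaces (a rough sketch, using placeholder `matcher`, `callback`, and `options` values):

run.on(matcher, callback, options)
# ...is roughly equivalent to:
run.observe(lambda emitter: emitter.match(matcher, callback, options))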
python/examples/agents/granite.py (12 changes: 4 additions & 8 deletions)
@@ -6,7 +6,6 @@
 from beeai_framework.agents.react.types import ReActAgentRunOutput
 from beeai_framework.agents.types import AgentExecutionConfig
 from beeai_framework.backend.chat import ChatModel
-from beeai_framework.emitter import Emitter, EventMeta
 from beeai_framework.errors import FrameworkError
 from beeai_framework.memory.unconstrained_memory import UnconstrainedMemory
 from beeai_framework.tools.search import DuckDuckGoSearchTool
@@ -25,15 +24,12 @@ async def main() -> None:
 
     prompt = reader.prompt()
 
-    def update_callback(data: dict, event: EventMeta) -> None:
-        reader.write(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"])
-
-    def on_update(emitter: Emitter) -> None:
-        emitter.on("update", update_callback)
-
     output: ReActAgentRunOutput = await agent.run(
         prompt=prompt, execution=AgentExecutionConfig(total_max_retries=2, max_retries_per_step=3, max_iterations=8)
-    ).observe(on_update)
+    ).on(
+        "update",
+        lambda data, event: reader.write(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"]),
+    )
 
     reader.write("Agent 🤖 : ", output.result.text)
python/examples/agents/react.py (8 changes: 2 additions & 6 deletions)
@@ -11,7 +11,7 @@
 from beeai_framework.agents.react.agent import ReActAgent
 from beeai_framework.agents.types import AgentExecutionConfig
 from beeai_framework.backend.chat import ChatModel, ChatModelParameters
-from beeai_framework.emitter.emitter import Emitter, EventMeta
+from beeai_framework.emitter.emitter import EventMeta
 from beeai_framework.emitter.types import EmitterOptions
 from beeai_framework.errors import FrameworkError
 from beeai_framework.logger import Logger
@@ -76,10 +76,6 @@ def process_agent_events(data: dict[str, Any], event: EventMeta) -> None:
         reader.write("Agent 🤖 : ", "success")
 
 
-def observer(emitter: Emitter) -> None:
-    emitter.on("*", process_agent_events, EmitterOptions(match_nested=False))
-
-
 async def main() -> None:
     """Main application loop"""
 
@@ -102,7 +98,7 @@ async def main() -> None:
     response = await agent.run(
         prompt=prompt,
         execution=AgentExecutionConfig(max_retries_per_step=3, total_max_retries=10, max_iterations=20),
-    ).observe(observer)
+    ).on("*", process_agent_events, EmitterOptions(match_nested=False))
 
     reader.write("Agent 🤖 : ", response.result.text)
python/examples/agents/simple.py (11 changes: 3 additions & 8 deletions)
@@ -5,7 +5,6 @@
 from beeai_framework.agents.react.agent import ReActAgent
 from beeai_framework.agents.react.types import ReActAgentRunOutput
 from beeai_framework.backend.chat import ChatModel
-from beeai_framework.emitter.emitter import Emitter, EventMeta
 from beeai_framework.errors import FrameworkError
 from beeai_framework.memory.unconstrained_memory import UnconstrainedMemory
 from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool
@@ -16,13 +15,9 @@ async def main() -> None:
     llm = ChatModel.from_name("ollama:granite3.1-dense:8b")
     agent = ReActAgent(llm=llm, tools=[DuckDuckGoSearchTool(), OpenMeteoTool()], memory=UnconstrainedMemory())
 
-    def update_callback(data: dict, event: EventMeta) -> None:
-        print(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"])
-
-    def on_update(emitter: Emitter) -> None:
-        emitter.on("update", update_callback)
-
-    output: ReActAgentRunOutput = await agent.run("What's the current weather in Las Vegas?").observe(on_update)
+    output: ReActAgentRunOutput = await agent.run("What's the current weather in Las Vegas?").on(
+        "update", lambda data, event: print(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"])
+    )
 
     print("Agent 🤖 : ", output.result.text)
