Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Run.on API to directly add an event callback #516

Merged
merged 2 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions python/beeai_framework/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pydantic import BaseModel

from beeai_framework.cancellation import AbortController, AbortSignal, register_signals
from beeai_framework.emitter import Emitter, EventTrace
from beeai_framework.emitter import Callback, Emitter, EmitterOptions, EventTrace, Matcher
from beeai_framework.errors import AbortError, FrameworkError
from beeai_framework.logger import Logger
from beeai_framework.utils.asynchronous import ensure_async
Expand All @@ -34,6 +34,8 @@

logger = Logger(__name__)

storage: ContextVar["RunContext"] = ContextVar("storage")


@dataclass
class RunInstance:
Expand All @@ -49,27 +51,31 @@ class Run(Generic[R]):
def __init__(self, handler: Callable[[], R | Awaitable[R]], context: "RunContext") -> None:
super().__init__()
self.handler = ensure_async(handler)
self.tasks: list[tuple[Callable[[Any], None], Any]] = []
self.tasks: list[tuple[Callable, list]] = []
self.run_context = context

def __await__(self) -> Generator[Any, None, R]:
return self._run_tasks().__await__()

def observe(self, fn: Callable[[Emitter], Any]) -> Self:
self.tasks.append((fn, self.run_context.emitter))
def observe(self, fn: Callable[[Emitter], None]) -> Self:
self.tasks.append((fn, [self.run_context.emitter]))
return self

def on(self, matcher: Matcher, callback: Callback, options: EmitterOptions | None = None) -> Self:
self.tasks.append((self.run_context.emitter.match, [matcher, callback, options]))
return self

def context(self, context: dict) -> Self:
self.tasks.append((self._set_context, context))
self.tasks.append((self._set_context, [context]))
return self

def middleware(self, fn: Callable[["RunContext"], None]) -> Self:
self.tasks.append((fn, self.run_context))
self.tasks.append((fn, [self.run_context]))
return self

async def _run_tasks(self) -> R:
for fn, param in self.tasks:
await ensure_async(fn)(param)
for fn, params in self.tasks:
await ensure_async(fn)(*params)

self.tasks.clear()
return await self.handler()
Expand All @@ -80,8 +86,6 @@ def _set_context(self, context: dict) -> None:


class RunContext(RunInstance):
storage: ContextVar[Self] = ContextVar("storage", default=None)

def __init__(self, *, instance: RunInstance, context_input: RunContextInput, parent: Self | None = None) -> None:
self.instance = instance
self.context_input = context_input
Expand Down Expand Up @@ -124,7 +128,7 @@ def destroy(self) -> None:
def enter(
instance: RunInstance, context_input: RunContextInput, fn: Callable[["RunContext"], Awaitable[R]]
) -> Run[R]:
parent = RunContext.storage.get()
parent = storage.get(None)
context = RunContext(instance=instance, context_input=context_input, parent=parent)

async def handler() -> R:
Expand All @@ -134,7 +138,7 @@ async def handler() -> R:
await emitter.emit("start", None)

async def _context_storage_run() -> R:
RunContext.storage.set(context)
storage.set(context)
return await fn(context)

async def _context_signal_aborted() -> None:
Expand Down
12 changes: 4 additions & 8 deletions python/examples/agents/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from beeai_framework.agents.react.types import ReActAgentRunOutput
from beeai_framework.agents.types import AgentExecutionConfig
from beeai_framework.backend.chat import ChatModel
from beeai_framework.emitter import Emitter, EventMeta
from beeai_framework.errors import FrameworkError
from beeai_framework.memory.unconstrained_memory import UnconstrainedMemory
from beeai_framework.tools.search import DuckDuckGoSearchTool
Expand All @@ -25,15 +24,12 @@ async def main() -> None:

prompt = reader.prompt()

def update_callback(data: dict, event: EventMeta) -> None:
reader.write(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"])

def on_update(emitter: Emitter) -> None:
emitter.on("update", update_callback)

output: ReActAgentRunOutput = await agent.run(
prompt=prompt, execution=AgentExecutionConfig(total_max_retries=2, max_retries_per_step=3, max_iterations=8)
).observe(on_update)
).on(
"update",
lambda data, event: reader.write(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"]),
)

reader.write("Agent 🤖 : ", output.result.text)

Expand Down
8 changes: 2 additions & 6 deletions python/examples/agents/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from beeai_framework.agents.react.agent import ReActAgent
from beeai_framework.agents.types import AgentExecutionConfig
from beeai_framework.backend.chat import ChatModel, ChatModelParameters
from beeai_framework.emitter.emitter import Emitter, EventMeta
from beeai_framework.emitter.emitter import EventMeta
from beeai_framework.emitter.types import EmitterOptions
from beeai_framework.errors import FrameworkError
from beeai_framework.logger import Logger
Expand Down Expand Up @@ -76,10 +76,6 @@ def process_agent_events(data: dict[str, Any], event: EventMeta) -> None:
reader.write("Agent 🤖 : ", "success")


def observer(emitter: Emitter) -> None:
emitter.on("*", process_agent_events, EmitterOptions(match_nested=False))


async def main() -> None:
"""Main application loop"""

Expand All @@ -102,7 +98,7 @@ async def main() -> None:
response = await agent.run(
prompt=prompt,
execution=AgentExecutionConfig(max_retries_per_step=3, total_max_retries=10, max_iterations=20),
).observe(observer)
).on("*", process_agent_events, EmitterOptions(match_nested=False))

reader.write("Agent 🤖 : ", response.result.text)

Expand Down
11 changes: 3 additions & 8 deletions python/examples/agents/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from beeai_framework.agents.react.agent import ReActAgent
from beeai_framework.agents.react.types import ReActAgentRunOutput
from beeai_framework.backend.chat import ChatModel
from beeai_framework.emitter.emitter import Emitter, EventMeta
from beeai_framework.errors import FrameworkError
from beeai_framework.memory.unconstrained_memory import UnconstrainedMemory
from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool
Expand All @@ -16,13 +15,9 @@ async def main() -> None:
llm = ChatModel.from_name("ollama:granite3.1-dense:8b")
agent = ReActAgent(llm=llm, tools=[DuckDuckGoSearchTool(), OpenMeteoTool()], memory=UnconstrainedMemory())

def update_callback(data: dict, event: EventMeta) -> None:
print(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"])

def on_update(emitter: Emitter) -> None:
emitter.on("update", update_callback)

output: ReActAgentRunOutput = await agent.run("What's the current weather in Las Vegas?").observe(on_update)
output: ReActAgentRunOutput = await agent.run("What's the current weather in Las Vegas?").on(
"update", lambda data, event: print(f"Agent({data['update']['key']}) 🤖 : ", data["update"]["parsedValue"])
)

print("Agent 🤖 : ", output.result.text)

Expand Down