Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add runcontext and emitter to workflow #416

Merged
merged 15 commits into from
Feb 28, 2025
2 changes: 1 addition & 1 deletion python/beeai_framework/emitter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def assert_valid_name(name: str) -> None:
if not name or not re.match("^[a-zA-Z0-9_]+$", name):
raise EmitterError(
"Event name or a namespace part must contain only letters, numbers or underscores.",
f"Event name or a namespace part must contain only letters, numbers or underscores: {name}",
)


Expand Down
4 changes: 2 additions & 2 deletions python/beeai_framework/workflows/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class AgentWorkflow:
def __init__(self, name: str = "AgentWorkflow") -> None:
self.workflow = Workflow(name=name, schema=Schema)

async def run(self, messages: list[Message]) -> WorkflowRun:
return await self.workflow.run(Schema(messages=messages))
def run(self, messages: list[Message]) -> WorkflowRun:
return self.workflow.run(Schema(messages=messages))

def del_agent(self, name: str) -> "AgentWorkflow":
self.workflow.delete_step(name)
Expand Down
131 changes: 96 additions & 35 deletions python/beeai_framework/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@

import asyncio
import inspect
import re
from collections.abc import Awaitable, Callable
from dataclasses import field
from typing import ClassVar, Final, Generic, Literal
from typing import Any, ClassVar, Final, Generic, Literal

from pydantic import BaseModel
from typing_extensions import TypeVar

from beeai_framework.utils.models import ModelLike, check_model, to_model
from beeai_framework.cancellation import AbortSignal
from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance
from beeai_framework.emitter.emitter import Emitter
from beeai_framework.emitter.types import EmitterInput
from beeai_framework.utils.models import ModelLike, check_model, to_model, to_model_optional
from beeai_framework.utils.types import MaybeAsync
from beeai_framework.workflows.errors import WorkflowError

Expand All @@ -46,12 +52,23 @@ class WorkflowStepDefinition(BaseModel, Generic[T, K]):
handler: WorkflowHandler[T, K]


class WorkflowRunContext(BaseModel, Generic[T, K]):
steps: list[WorkflowStepRes[T, K]] = field(default_factory=list)
signal: AbortSignal
abort: Callable[[Any], None]


class WorkflowRun(BaseModel, Generic[T, K]):
state: T
result: T | None = None
steps: list[WorkflowStepRes[T, K]] = field(default_factory=list)


class WorkflowRunOptions(BaseModel, Generic[K]):
start: K | None = None
signal: AbortSignal | None = None


class Workflow(Generic[T, K]):
START: Final[Literal["__start__"]] = "__start__"
SELF: Final[Literal["__self__"]] = "__self__"
Expand All @@ -61,12 +78,23 @@ class Workflow(Generic[T, K]):

_RESERVED_STEP_NAMES: ClassVar = [START, SELF, PREV, NEXT, END]

emitter: Emitter

def __init__(self, schema: type[T], name: str = "Workflow") -> None:
self._name = name
self._schema = schema
self._steps: dict[K, WorkflowStepDefinition[T, K]] = {}
self._start_step: K | None = None

# replace any non-alphanumeric char with _
formatted_name = re.sub(r"\W+", "_", self._name).lower()
self.emitter = Emitter.root().child(
EmitterInput(
namespace=["workflow", formatted_name],
creator=self,
)
)

@property
def steps(self) -> dict[K, WorkflowStepDefinition[T, K]]:
return self._steps
Expand Down Expand Up @@ -116,39 +144,72 @@ def set_start(self, name: K) -> "Workflow[T, K]":
self._start_step = name
return self

async def run(self, state: ModelLike[T]) -> WorkflowRun[T, K]:
run = WorkflowRun[T, K](state=to_model(self._schema, state))
next = self._find_step(self.start_step or self.step_names[0]).current or Workflow.END

while next and next != Workflow.END:
step = self.steps.get(next)
if step is None:
raise WorkflowError(f"Step '{next}' was not found.")

step_res = WorkflowStepRes[T, K](name=next, state=run.state.model_copy(deep=True))
run.steps.append(step_res)

if inspect.iscoroutinefunction(step.handler):
step_next = await step.handler(step_res.state)
else:
step_next = await asyncio.to_thread(step.handler, step_res.state)

check_model(step_res.state)
run.state = step_res.state

# Route to next step
if step_next == Workflow.START:
next = run.steps[0].name
elif step_next == Workflow.PREV:
next = run.steps[-2].name
elif step_next == Workflow.SELF:
next = run.steps[-1].name
elif step_next is None or step_next == Workflow.NEXT:
next = self._find_step(next).next or Workflow.END
else:
next = step_next

return run
def run(self, state: ModelLike[T], options: ModelLike[WorkflowRunOptions] | None = None) -> Run[WorkflowRun[T, K]]:
options = to_model_optional(WorkflowRunOptions, options)

async def run_workflow(context: RunContext) -> Awaitable[WorkflowRun[T, K]]:
run = WorkflowRun[T, K](state=to_model(self._schema, state))
# handlers = WorkflowRunContext(steps=run.steps, signal=context.signal, abort=lambda r: context.abort(r))
next = self._find_step(self.start_step or self.step_names[0]).current or Workflow.END

while next and next != Workflow.END:
step = self.steps.get(next)
if step is None:
raise WorkflowError(f"Step '{next}' was not found.")

await context.emitter.emit("start", {"run": run, "step": next})

try:
step_res = WorkflowStepRes[T, K](name=next, state=run.state.model_copy(deep=True))
run.steps.append(step_res)

if inspect.iscoroutinefunction(step.handler):
step_next = await step.handler(step_res.state) # , handlers)
else:
step_next = await asyncio.to_thread(step.handler, step_res.state) # handlers)

check_model(step_res.state)
run.state = step_res.state

# Route to next step
if step_next == Workflow.START:
next = run.steps[0].name
elif step_next == Workflow.PREV:
next = run.steps[-2].name
elif step_next == Workflow.SELF:
next = run.steps[-1].name
elif step_next is None or step_next == Workflow.NEXT:
next = self._find_step(next).next or Workflow.END
else:
next = step_next

await context.emitter.emit(
"success",
{
"run": run.model_copy(),
"state": run.state,
"step": step,
"next": next,
},
)
except Exception as err:
await context.emitter.emit(
"error",
{
"run": run.model_copy(),
"step": next,
"error": err,
},
)
raise err

return run

return RunContext.enter(
RunInstance(emitter=self.emitter),
RunContextInput(params=[state, options], signal=options.signal if options else None),
run_workflow,
)

def _find_step(self, current: K) -> WorkflowState[K]:
index = self.step_names.index(current)
Expand Down
90 changes: 90 additions & 0 deletions python/examples/workflows/emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
from typing import Literal, TypeAlias

from pydantic import BaseModel, ValidationError

from beeai_framework.emitter.emitter import Emitter, EventMeta
from beeai_framework.emitter.types import EmitterOptions
from beeai_framework.workflows.workflow import Workflow, WorkflowError, WorkflowReservedStepName


def print_event(event_data: dict, event_meta: EventMeta) -> None:
"""Process agent events and log appropriately"""

if event_meta.name == "error":
print("Workflow : ", event_data)
elif event_meta.name == "retry":
print("Workflow : ", "retrying the action...")
elif event_meta.name == "update":
print(f"Agent({event_data['update']['key']}) 🤖 : ", event_data["update"]["parsedValue"])
elif event_meta.name == "start":
if event_data:
print(f"Workflow : Starting step: {event_data.get('step')}")
else:
print("Workflow : Starting")
elif event_meta.name == "success":
if isinstance(event_data, dict):
run = event_data.get("run")
print(f"Workflow : Completed step: {run.steps[-1].name}, Result: {run.state.result}")
print(f"Workflow : Next step: {event_data.get('next')}")
else:
print("Workflow : Result: ", event_data.result)
elif event_meta.name == "finish":
print("Workflow : Finished")


async def main() -> None:
# State
class State(BaseModel):
x: int
y: int
abs_repetitions: int | None = None
result: int | None = None

WorkflowStep: TypeAlias = Literal["pre_process", "add_loop", "post_process"]

# Observe the agent
async def observer(emitter: Emitter) -> None:
emitter.on("*.*", print_event, EmitterOptions(match_nested=True))

def pre_process(state: State) -> WorkflowStep:
state.abs_repetitions = abs(state.y)
return "add_loop"

def add_loop(state: State) -> WorkflowStep | WorkflowReservedStepName:
if state.abs_repetitions and state.abs_repetitions > 0:
result = (state.result if state.result is not None else 0) + state.x
abs_repetitions = (state.abs_repetitions if state.abs_repetitions is not None else 0) - 1
print(f"add_loop: intermediate result {result}")
state.abs_repetitions = abs_repetitions
state.result = result
return Workflow.SELF
else:
return "post_process"

def post_process(state: State) -> WorkflowReservedStepName:
if state.y < 0:
result = -(state.result if state.result is not None else 0)
state.result = result
return Workflow.END

try:
multiplication_workflow = Workflow[State, WorkflowStep](name="MultiplicationWorkflow", schema=State)
multiplication_workflow.add_step("pre_process", pre_process)
multiplication_workflow.add_step("add_loop", add_loop)
multiplication_workflow.add_step("post_process", post_process)

response = await multiplication_workflow.run(State(x=8, y=5)).observe(observer)
print(f"result: {response.state.result}")

response = await multiplication_workflow.run(State(x=8, y=-5)).observe(observer)
print(f"result: {response.state.result}")

except WorkflowError as e:
print(e)
except ValidationError as e:
print(e)


if __name__ == "__main__":
asyncio.run(main())