diff --git a/python/packages/core/agent_framework/_workflows/_concurrent.py b/python/packages/core/agent_framework/_workflows/_concurrent.py index f6a7b09e60..a6fcaa1a3e 100644 --- a/python/packages/core/agent_framework/_workflows/_concurrent.py +++ b/python/packages/core/agent_framework/_workflows/_concurrent.py @@ -3,6 +3,7 @@ import asyncio import inspect import logging +import uuid from collections.abc import Callable, Sequence from typing import Any @@ -189,8 +190,11 @@ class ConcurrentBuilder: r"""High-level builder for concurrent agent workflows. - `participants([...])` accepts a list of AgentProtocol (recommended) or Executor. + - `register_participants([...])` accepts a list of factories for AgentProtocol (recommended) + or Executor factories - `build()` wires: dispatcher -> fan-out -> participants -> fan-in -> aggregator. - - `with_custom_aggregator(...)` overrides the default aggregator with an Executor or callback. + - `with_aggregator(...)` overrides the default aggregator with an Executor or callback. + - `register_aggregator(...)` accepts a factory for an Executor as custom aggregator. Usage: @@ -201,14 +205,33 @@ class ConcurrentBuilder: # Minimal: use default aggregator (returns list[ChatMessage]) workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).build() + # With agent factories + workflow = ConcurrentBuilder().register_participants([create_agent1, create_agent2, create_agent3]).build() + # Custom aggregator via callback (sync or async). The callback receives # list[AgentExecutorResponse] and its return value becomes the workflow's output. 
- def summarize(results): + def summarize(results: list[AgentExecutorResponse]) -> str: return " | ".join(r.agent_run_response.messages[-1].text for r in results) - workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_custom_aggregator(summarize).build() + workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_aggregator(summarize).build() + + + # Custom aggregator via a factory + class MyAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results)) + + + workflow = ( + ConcurrentBuilder() + .register_participants([create_agent1, create_agent2, create_agent3]) + .register_aggregator(lambda: MyAggregator(id="my_aggregator")) + .build() + ) + # Enable checkpoint persistence so runs can resume workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_checkpointing(storage).build() @@ -219,10 +242,67 @@ def summarize(results): def __init__(self) -> None: self._participants: list[AgentProtocol | Executor] = [] + self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = [] self._aggregator: Executor | None = None + self._aggregator_factory: Callable[[], Executor] | None = None self._checkpoint_storage: CheckpointStorage | None = None self._request_info_enabled: bool = False + def register_participants( + self, + participant_factories: Sequence[Callable[[], AgentProtocol | Executor]], + ) -> "ConcurrentBuilder": + r"""Define the parallel participants for this concurrent workflow. + + Accepts factories (callables) that return AgentProtocol instances (e.g., created + by a chat client) or Executor instances. Each participant created by a factory + is wired as a parallel branch using fan-out edges from an internal dispatcher. 
+ + Args: + participant_factories: Sequence of callables returning AgentProtocol or Executor instances + + Raises: + ValueError: if `participant_factories` is empty or `.participants()` + or `.register_participants()` were already called + + Example: + + .. code-block:: python + + def create_researcher() -> ChatAgent: + return ... + + + def create_marketer() -> ChatAgent: + return ... + + + def create_legal() -> ChatAgent: + return ... + + + class MyCustomExecutor(Executor): ... + + + wf = ConcurrentBuilder().register_participants([create_researcher, create_marketer, create_legal]).build() + + # Mixing agent(s) and executor(s) is supported + wf2 = ConcurrentBuilder().register_participants([create_researcher, MyCustomExecutor]).build() + """ + if self._participants: + raise ValueError( + "Cannot mix .participants([...]) and .register_participants() in the same builder instance." + ) + + if self._participant_factories: + raise ValueError("register_participants() has already been called on this builder instance.") + + if not participant_factories: + raise ValueError("participant_factories cannot be empty") + + self._participant_factories = list(participant_factories) + return self + def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "ConcurrentBuilder": r"""Define the parallel participants for this concurrent workflow. @@ -230,8 +310,12 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Con instances. Each participant is wired as a parallel branch using fan-out edges from an internal dispatcher. 
+ Args: + participants: Sequence of AgentProtocol or Executor instances + Raises: - ValueError: if `participants` is empty or contains duplicates + ValueError: if `participants` is empty, contains duplicates, or `.register_participants()` + or `.participants()` were already called TypeError: if any entry is not AgentProtocol or Executor Example: @@ -243,6 +327,14 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Con # Mixing agent(s) and executor(s) is supported wf2 = ConcurrentBuilder().participants([researcher_agent, my_custom_executor]).build() """ + if self._participant_factories: + raise ValueError( + "Cannot mix .participants([...]) and .register_participants() in the same builder instance." + ) + + if self._participants: + raise ValueError("participants() has already been called on this builder instance.") + if not participants: raise ValueError("participants cannot be empty") @@ -265,38 +357,107 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Con self._participants = list(participants) return self - def with_aggregator(self, aggregator: Executor | Callable[..., Any]) -> "ConcurrentBuilder": - r"""Override the default aggregator with an Executor or a callback. + def register_aggregator(self, aggregator_factory: Callable[[], Executor]) -> "ConcurrentBuilder": + r"""Define a custom aggregator for this concurrent workflow. + + Accepts a factory (callable) that returns an Executor instance. The executor + should handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)`. + + Args: + aggregator_factory: Callable that returns an Executor instance + + Example: + .. code-block:: python + + class MyCustomExecutor(Executor): ... + - - Executor: must handle `list[AgentExecutorResponse]` and - yield output using `ctx.yield_output(...)` and add a - output and the workflow becomes idle. 
+ wf = ( + ConcurrentBuilder() + .register_participants([create_researcher, create_marketer, create_legal]) + .register_aggregator(lambda: MyCustomExecutor(id="my_aggregator")) + .build() + ) + """ + if self._aggregator is not None: + raise ValueError( + "Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance." + ) + + if self._aggregator_factory is not None: + raise ValueError("register_aggregator() has already been called on this builder instance.") + + self._aggregator_factory = aggregator_factory + return self + + def with_aggregator( + self, + aggregator: Executor + | Callable[[list[AgentExecutorResponse]], Any] + | Callable[[list[AgentExecutorResponse], WorkflowContext[Never, Any]], Any], + ) -> "ConcurrentBuilder": + r"""Override the default aggregator with an executor, an executor factory, or a callback. + + - Executor: must handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)` - Callback: sync or async callable with one of the signatures: `(results: list[AgentExecutorResponse]) -> Any | None` or `(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None`. If the callback returns a non-None value, it becomes the workflow's output. + Args: + aggregator: Executor instance, or callback function + Example: .. 
code-block:: python + # Executor-based aggregator + class CustomAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext) -> None: + await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results)) + + + wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(CustomAggregator()).build() + # Callback-based aggregator (string result) - async def summarize(results): + async def summarize(results: list[AgentExecutorResponse]) -> str: return " | ".join(r.agent_run_response.messages[-1].text for r in results) - wf = ConcurrentBuilder().participants([a1, a2, a3]).with_custom_aggregator(summarize).build() + wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build() + + + # Callback-based aggregator (yield result) + async def summarize(results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results)) + + + wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build() """ + if self._aggregator_factory is not None: + raise ValueError( + "Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance." + ) + + if self._aggregator is not None: + raise ValueError("with_aggregator() has already been called on this builder instance.") + if isinstance(aggregator, Executor): self._aggregator = aggregator elif callable(aggregator): self._aggregator = _CallbackAggregator(aggregator) else: raise TypeError("aggregator must be an Executor or a callable") + return self def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "ConcurrentBuilder": - """Enable checkpoint persistence using the provided storage backend.""" + """Enable checkpoint persistence using the provided storage backend. 
+ + Args: + checkpoint_storage: CheckpointStorage instance for persisting workflow state + """ self._checkpoint_storage = checkpoint_storage return self @@ -329,7 +490,7 @@ def build(self) -> Workflow: before sending the outputs to the aggregator - Aggregator yields output and the workflow becomes idle. The output is either: - list[ChatMessage] (default aggregator: one user + one assistant per agent) - - custom payload from the provided callback/executor + - custom payload from the provided aggregator Returns: Workflow: a ready-to-run workflow instance @@ -343,26 +504,69 @@ def build(self) -> Workflow: workflow = ConcurrentBuilder().participants([agent1, agent2]).build() """ - if not self._participants: - raise ValueError("No participants provided. Call .participants([...]) first.") + if not self._participants and not self._participant_factories: + raise ValueError( + "No participants provided. Call .participants([...]) or .register_participants([...]) first." + ) + # Internal nodes dispatcher = _DispatchToAllParticipants(id="dispatcher") - aggregator = self._aggregator or _AggregateAgentConversations(id="aggregator") + aggregator = ( + self._aggregator + if self._aggregator is not None + else ( + self._aggregator_factory() + if self._aggregator_factory is not None + else _AggregateAgentConversations(id="aggregator") + ) + ) builder = WorkflowBuilder() - builder.set_start_executor(dispatcher) - builder.add_fan_out_edges(dispatcher, list(self._participants)) - - if self._request_info_enabled: - # Insert interceptor between fan-in and aggregator - # participants -> fan-in -> interceptor -> aggregator - request_info_interceptor = RequestInfoInterceptor(executor_id="request_info") - builder.add_fan_in_edges(list(self._participants), request_info_interceptor) - builder.add_edge(request_info_interceptor, aggregator) + if self._participant_factories: + # Register executors/agents to avoid warnings from the workflow builder + # if factories are provided instead of direct 
instances. This doesn't + # break the factory pattern since the concurrent builder still creates + # new instances per workflow build. + factory_names: list[str] = [] + for factory in self._participant_factories: + factory_name = uuid.uuid4().hex + factory_names.append(factory_name) + instance = factory() + if isinstance(instance, Executor): + builder.register_executor(lambda executor=instance: executor, name=factory_name) # type: ignore[misc] + else: + builder.register_agent(lambda agent=instance: agent, name=factory_name) # type: ignore[misc] + # Register the dispatcher and the aggregator + builder.register_executor(lambda: dispatcher, name="dispatcher") + builder.register_executor(lambda: aggregator, name="aggregator") + + builder.set_start_executor("dispatcher") + builder.add_fan_out_edges("dispatcher", factory_names) + if self._request_info_enabled: + # Insert interceptor between fan-in and aggregator + # participants -> fan-in -> interceptor -> aggregator + builder.register_executor( + lambda: RequestInfoInterceptor(executor_id="request_info"), + name="request_info_interceptor", + ) + builder.add_fan_in_edges(factory_names, "request_info_interceptor") + builder.add_edge("request_info_interceptor", "aggregator") + else: + # Direct fan-in to aggregator + builder.add_fan_in_edges(factory_names, "aggregator") else: - # Direct fan-in to aggregator - builder.add_fan_in_edges(list(self._participants), aggregator) - + builder.set_start_executor(dispatcher) + builder.add_fan_out_edges(dispatcher, self._participants) + + if self._request_info_enabled: + # Insert interceptor between fan-in and aggregator + # participants -> fan-in -> interceptor -> aggregator + request_info_interceptor = RequestInfoInterceptor(executor_id="request_info") + builder.add_fan_in_edges(self._participants, request_info_interceptor) + builder.add_edge(request_info_interceptor, aggregator) + else: + # Direct fan-in to aggregator + builder.add_fan_in_edges(self._participants, aggregator) if 
self._checkpoint_storage is not None: builder = builder.with_checkpointing(self._checkpoint_storage) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 26cd0213e4..5bf36b6ccd 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -374,7 +374,7 @@ def register_agent( ) """ if name in self._executor_registry: - raise ValueError(f"An executor factory with the name '{name}' is already registered.") + raise ValueError(f"An agent factory with the name '{name}' is already registered.") def wrapped_factory() -> AgentExecutor: agent = factory_func() @@ -1148,21 +1148,29 @@ def _resolve_edge_registry(self) -> tuple[Executor, list[Executor], list[EdgeGro if isinstance(self._start_executor, Executor): start_executor = self._start_executor - executors: dict[str, Executor] = {} + # Maps registered factory names to created executor instances for edge resolution + factory_name_to_instance: dict[str, Executor] = {} + # Maps executor IDs to created executor instances to prevent duplicates + executor_id_to_instance: dict[str, Executor] = {} deferred_edge_groups: list[EdgeGroup] = [] for name, exec_factory in self._executor_registry.items(): instance = exec_factory() + if instance.id in executor_id_to_instance: + raise ValueError(f"Executor with ID '{instance.id}' has already been created.") + executor_id_to_instance[instance.id] = instance + if isinstance(self._start_executor, str) and name == self._start_executor: start_executor = instance + # All executors will get their own internal edge group for receiving system messages deferred_edge_groups.append(InternalEdgeGroup(instance.id)) # type: ignore[call-arg] - executors[name] = instance + factory_name_to_instance[name] = instance def _get_executor(name: str) -> Executor: """Helper to get executor by the registered name. 
Raises if not found.""" - if name not in executors: - raise ValueError(f"Executor with name '{name}' has not been registered.") - return executors[name] + if name not in factory_name_to_instance: + raise ValueError(f"Factory '{name}' has not been registered.") + return factory_name_to_instance[name] for registration in self._edge_registry: match registration: @@ -1179,7 +1187,7 @@ def _get_executor(name: str) -> Executor: cases_converted: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = [] for case in cases: if not isinstance(case.target, str): - raise ValueError("Switch case target must be a registered executor name (str) if deferred.") + raise ValueError("Switch case target must be a registered factory name (str) if deferred.") target_exec = _get_executor(case.target) if isinstance(case, Default): cases_converted.append(SwitchCaseEdgeGroupDefault(target_id=target_exec.id)) @@ -1201,7 +1209,7 @@ def _get_executor(name: str) -> Executor: if start_executor is None: raise ValueError("Failed to resolve starting executor from registered factories.") - return start_executor, list(executors.values()), deferred_edge_groups + return start_executor, list(executor_id_to_instance.values()), deferred_edge_groups def build(self) -> Workflow: """Build and return the constructed workflow. 
diff --git a/python/packages/core/tests/workflow/test_concurrent.py b/python/packages/core/tests/workflow/test_concurrent.py index db70be3f38..66cc8cfc68 100644 --- a/python/packages/core/tests/workflow/test_concurrent.py +++ b/python/packages/core/tests/workflow/test_concurrent.py @@ -3,6 +3,7 @@ from typing import Any, cast import pytest +from typing_extensions import Never from agent_framework import ( AgentExecutorRequest, @@ -52,6 +53,55 @@ def test_concurrent_builder_rejects_duplicate_executors() -> None: ConcurrentBuilder().participants([a, b]) +def test_concurrent_builder_rejects_duplicate_executors_from_factories() -> None: + """Test that duplicate executor IDs from factories are detected at build time.""" + + def create_dup1() -> Executor: + return _FakeAgentExec("dup", "A") + + def create_dup2() -> Executor: + return _FakeAgentExec("dup", "B") # same executor id + + builder = ConcurrentBuilder().register_participants([create_dup1, create_dup2]) + with pytest.raises(ValueError, match="Executor with ID 'dup' has already been created."): + builder.build() + + +def test_concurrent_builder_rejects_mixed_participants_and_factories() -> None: + """Test that mixing .participants() and .register_participants() raises an error.""" + # Case 1: participants first, then register_participants + with pytest.raises(ValueError, match="Cannot mix .participants"): + ( + ConcurrentBuilder() + .participants([_FakeAgentExec("a", "A")]) + .register_participants([lambda: _FakeAgentExec("b", "B")]) + ) + + # Case 2: register_participants first, then participants + with pytest.raises(ValueError, match="Cannot mix .participants"): + ( + ConcurrentBuilder() + .register_participants([lambda: _FakeAgentExec("a", "A")]) + .participants([_FakeAgentExec("b", "B")]) + ) + + +def test_concurrent_builder_rejects_multiple_calls_to_participants() -> None: + """Test that multiple calls to .participants() raises an error.""" + with pytest.raises(ValueError, match=r"participants\(\) has already 
been called"): + (ConcurrentBuilder().participants([_FakeAgentExec("a", "A")]).participants([_FakeAgentExec("b", "B")])) + + +def test_concurrent_builder_rejects_multiple_calls_to_register_participants() -> None: + """Test that multiple calls to .register_participants() raises an error.""" + with pytest.raises(ValueError, match=r"register_participants\(\) has already been called"): + ( + ConcurrentBuilder() + .register_participants([lambda: _FakeAgentExec("a", "A")]) + .register_participants([lambda: _FakeAgentExec("b", "B")]) + ) + + async def test_concurrent_default_aggregator_emits_single_user_and_assistants() -> None: # Three synthetic agent executors e1 = _FakeAgentExec("agentA", "Alpha") @@ -159,6 +209,138 @@ def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[over assert aggregator.id == "summarize" +async def test_concurrent_with_aggregator_executor_instance() -> None: + """Test with_aggregator using an Executor instance (not factory).""" + + class CustomAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + texts: list[str] = [] + for r in results: + msgs: list[ChatMessage] = r.agent_run_response.messages + texts.append(msgs[-1].text if msgs else "") + await ctx.yield_output(" & ".join(sorted(texts))) + + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + aggregator_instance = CustomAggregator(id="instance_aggregator") + wf = ConcurrentBuilder().participants([e1, e2]).with_aggregator(aggregator_instance).build() + + completed = False + output: str | None = None + async for ev in wf.run_stream("prompt: instance test"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(str, ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + assert isinstance(output, str) + 
assert output == "One & Two" + + +async def test_concurrent_with_aggregator_executor_factory() -> None: + """Test with_aggregator using an Executor factory.""" + + class CustomAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + texts: list[str] = [] + for r in results: + msgs: list[ChatMessage] = r.agent_run_response.messages + texts.append(msgs[-1].text if msgs else "") + await ctx.yield_output(" | ".join(sorted(texts))) + + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + wf = ( + ConcurrentBuilder() + .participants([e1, e2]) + .register_aggregator(lambda: CustomAggregator(id="custom_aggregator")) + .build() + ) + + completed = False + output: str | None = None + async for ev in wf.run_stream("prompt: factory test"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(str, ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + assert isinstance(output, str) + assert output == "One | Two" + + +async def test_concurrent_with_aggregator_executor_factory_with_default_id() -> None: + """Test with_aggregator using an Executor class directly as factory (with default __init__ parameters).""" + + class CustomAggregator(Executor): + def __init__(self, id: str = "default_aggregator") -> None: + super().__init__(id) + + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + texts: list[str] = [] + for r in results: + msgs: list[ChatMessage] = r.agent_run_response.messages + texts.append(msgs[-1].text if msgs else "") + await ctx.yield_output(" | ".join(sorted(texts))) + + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + wf = ConcurrentBuilder().participants([e1, 
e2]).register_aggregator(CustomAggregator).build() + + completed = False + output: str | None = None + async for ev in wf.run_stream("prompt: factory test"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(str, ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + assert isinstance(output, str) + assert output == "One | Two" + + +def test_concurrent_builder_rejects_multiple_calls_to_with_aggregator() -> None: + """Test that multiple calls to .with_aggregator() raises an error.""" + + def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override] + return str(len(results)) + + with pytest.raises(ValueError, match=r"with_aggregator\(\) has already been called"): + (ConcurrentBuilder().with_aggregator(summarize).with_aggregator(summarize)) + + +def test_concurrent_builder_rejects_multiple_calls_to_register_aggregator() -> None: + """Test that multiple calls to .register_aggregator() raises an error.""" + + class CustomAggregator(Executor): + pass + + with pytest.raises(ValueError, match=r"register_aggregator\(\) has already been called"): + ( + ConcurrentBuilder() + .register_aggregator(lambda: CustomAggregator(id="agg1")) + .register_aggregator(lambda: CustomAggregator(id="agg2")) + ) + + async def test_concurrent_checkpoint_resume_round_trip() -> None: storage = InMemoryCheckpointStorage() @@ -278,3 +460,92 @@ async def test_concurrent_checkpoint_runtime_overrides_buildtime() -> None: assert len(runtime_checkpoints) > 0, "Runtime storage should have checkpoints" assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden" + + +def test_concurrent_builder_rejects_empty_participant_factories() -> None: + with pytest.raises(ValueError): + ConcurrentBuilder().register_participants([]) + + +async def 
test_concurrent_builder_reusable_after_build_with_participants() -> None: + """Test that the builder can be reused to build multiple identical workflows with participants().""" + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + builder = ConcurrentBuilder().participants([e1, e2]) + + builder.build() + + assert builder._participants[0] is e1 # type: ignore + assert builder._participants[1] is e2 # type: ignore + assert builder._participant_factories == [] # type: ignore + + +async def test_concurrent_builder_reusable_after_build_with_factories() -> None: + """Test that the builder can be reused to build multiple workflows with register_participants().""" + call_count = 0 + + def create_agent_executor_a() -> Executor: + nonlocal call_count + call_count += 1 + return _FakeAgentExec("agentA", "One") + + def create_agent_executor_b() -> Executor: + nonlocal call_count + call_count += 1 + return _FakeAgentExec("agentB", "Two") + + builder = ConcurrentBuilder().register_participants([create_agent_executor_a, create_agent_executor_b]) + + # Build the first workflow + wf1 = builder.build() + + assert builder._participants == [] # type: ignore + assert len(builder._participant_factories) == 2 # type: ignore + assert call_count == 2 + + # Build the second workflow + wf2 = builder.build() + assert call_count == 4 + + # Verify that the two workflows have different executor instances + assert wf1.executors["agentA"] is not wf2.executors["agentA"] + assert wf1.executors["agentB"] is not wf2.executors["agentB"] + + +async def test_concurrent_with_register_participants() -> None: + """Test workflow creation using register_participants with factories.""" + + def create_agent1() -> Executor: + return _FakeAgentExec("agentA", "Alpha") + + def create_agent2() -> Executor: + return _FakeAgentExec("agentB", "Beta") + + def create_agent3() -> Executor: + return _FakeAgentExec("agentC", "Gamma") + + wf = 
ConcurrentBuilder().register_participants([create_agent1, create_agent2, create_agent3]).build() + + completed = False + output: list[ChatMessage] | None = None + async for ev in wf.run_stream("test prompt"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(list[ChatMessage], ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + messages: list[ChatMessage] = output + + # Expect one user message + one assistant message per participant + assert len(messages) == 1 + 3 + assert messages[0].role == Role.USER + assert "test prompt" in messages[0].text + + assistant_texts = {m.text for m in messages[1:]} + assert assistant_texts == {"Alpha", "Beta", "Gamma"} + assert all(m.role == Role.ASSISTANT for m in messages[1:]) diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index a037bf51b6..b281edee34 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -293,6 +293,20 @@ def test_register_duplicate_name_raises_error(): builder.register_executor(lambda: MockExecutor(id="executor_2"), name="MyExecutor") +def test_register_duplicate_id_raises_error(): + """Test that registering duplicate id raises an error.""" + builder = WorkflowBuilder() + + # Register first executor + builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor1") + builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor2") + builder.set_start_executor("MyExecutor1") + + # Registering second executor with same ID should raise ValueError + with pytest.raises(ValueError, match="Executor with ID 'executor' has already been created."): + builder.build() + + def test_register_agent_basic(): """Test basic agent registration with lazy 
initialization.""" builder = WorkflowBuilder() diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index e1e18eab91..53dce73815 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -110,6 +110,7 @@ For additional observability samples in Agent Framework, see the [observability | Concurrent Orchestration (Default Aggregator) | [orchestration/concurrent_agents.py](./orchestration/concurrent_agents.py) | Fan-out to multiple agents; fan-in with default aggregator returning combined ChatMessages | | Concurrent Orchestration (Custom Aggregator) | [orchestration/concurrent_custom_aggregator.py](./orchestration/concurrent_custom_aggregator.py) | Override aggregator via callback; summarize results with an LLM | | Concurrent Orchestration (Custom Agent Executors) | [orchestration/concurrent_custom_agent_executors.py](./orchestration/concurrent_custom_agent_executors.py) | Child executors own ChatAgents; concurrent fan-out/fan-in via ConcurrentBuilder | +| Concurrent Orchestration (Participant Factory) | [orchestration/concurrent_participant_factory.py](./orchestration/concurrent_participant_factory.py) | Use participant factories for state isolation between workflow instances | | Group Chat with Agent Manager | [orchestration/group_chat_agent_manager.py](./orchestration/group_chat_agent_manager.py) | Agent-based manager using `set_manager()` to select next speaker | | Group Chat Philosophical Debate | [orchestration/group_chat_philosophical_debate.py](./orchestration/group_chat_philosophical_debate.py) | Agent manager moderates long-form, multi-round debate across diverse participants | | Group Chat with Simple Function Selector | [orchestration/group_chat_simple_selector.py](./orchestration/group_chat_simple_selector.py) | Group chat with a simple function selector for next speaker | diff --git 
a/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py b/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py index 4ad8c9fcb3..44f71ba7bc 100644 --- a/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py +++ b/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py @@ -17,7 +17,7 @@ The workflow completes when all participants become idle. Demonstrates: -- ConcurrentBuilder().participants([...]).with_custom_aggregator(callback) +- ConcurrentBuilder().participants([...]).with_aggregator(callback) - Fan-out to agents and fan-in at an aggregator - Aggregation implemented via an LLM call (chat_client.get_response) - Workflow output yielded with the synthesized summary string diff --git a/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py b/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py new file mode 100644 index 0000000000..435e59b2ba --- /dev/null +++ b/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Any, Never + +from agent_framework import ( + ChatAgent, + ChatMessage, + ConcurrentBuilder, + Executor, + Role, + Workflow, + WorkflowContext, + handler, +) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential + +""" +Sample: Concurrent Orchestration with participant factories and Custom Aggregator + +Build a concurrent workflow with ConcurrentBuilder that fans out one prompt to +multiple domain agents and fans in their responses. + +Override the default aggregator with a custom Executor class that uses +AzureOpenAIChatClient.get_response() to synthesize a concise, consolidated summary +from the experts' outputs. 
+ +All participants and the aggregator are created via factory functions that return +their respective ChatAgent or Executor instances. + +Using participant factories allows you to set up proper state isolation between workflow +instances created by the same builder. This is particularly useful when you need to handle +requests or tasks in parallel with stateful participants. + +Demonstrates: +- ConcurrentBuilder().register_participants([...]).register_aggregator(factory) +- Fan-out to agents and fan-in at an aggregator +- Aggregation implemented via an LLM call (chat_client.get_response) +- Workflow output yielded with the synthesized summary string + +Prerequisites: +- Azure OpenAI configured for AzureOpenAIChatClient (az login + required env vars) +""" + + +def create_researcher() -> ChatAgent: + """Factory function to create a researcher agent instance.""" + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + instructions=( + "You're an expert market and product researcher. Given a prompt, provide concise, factual insights," + " opportunities, and risks." + ), + name="researcher", + ) + + +def create_marketer() -> ChatAgent: + """Factory function to create a marketer agent instance.""" + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + instructions=( + "You're a creative marketing strategist. Craft compelling value propositions and target messaging" + " aligned to the prompt." + ), + name="marketer", + ) + + +def create_legal() -> ChatAgent: + """Factory function to create a legal/compliance agent instance.""" + return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( + instructions=( + "You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns" + " based on the prompt."
+ ), + name="legal", + ) + + +class SummarizationExecutor(Executor): + """Custom aggregator executor that synthesizes expert outputs into a concise summary.""" + + def __init__(self) -> None: + super().__init__(id="summarization_executor") + self.chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + + @handler + async def summarize_results(self, results: list[Any], ctx: WorkflowContext[Never, str]) -> None: + expert_sections: list[str] = [] + for r in results: + try: + messages = getattr(r.agent_run_response, "messages", []) + final_text = messages[-1].text if messages and hasattr(messages[-1], "text") else "(no content)" + expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}:\n{final_text}") + except Exception as e: + expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}: (error: {type(e).__name__}: {e})") + + # Ask the model to synthesize a concise summary of the experts' outputs + system_msg = ChatMessage( + Role.SYSTEM, + text=( + "You are a helpful assistant that consolidates multiple domain expert outputs " + "into one cohesive, concise summary with clear takeaways. Keep it under 200 words." + ), + ) + user_msg = ChatMessage(Role.USER, text="\n\n".join(expert_sections)) + + response = await self.chat_client.get_response([system_msg, user_msg]) + + await ctx.yield_output(response.messages[-1].text if response.messages else "") + + +async def run_workflow(workflow: Workflow, query: str) -> None: + events = await workflow.run(query) + outputs = events.get_outputs() + + if outputs: + print(outputs[0]) # Get the first (and typically only) output + else: + raise RuntimeError("No outputs received from the workflow.") + + +async def main() -> None: + # Create a concurrent builder with participant factories and a custom aggregator + # - register_participants([...]) accepts factory functions that return + # AgentProtocol (agents) or Executor instances. + # - register_aggregator(...) 
takes a factory function that returns an Executor instance. + concurrent_builder = ( + ConcurrentBuilder() + .register_participants([create_researcher, create_marketer, create_legal]) + .register_aggregator(SummarizationExecutor) + ) + + # Build workflow_a + workflow_a = concurrent_builder.build() + + # Run workflow_a + # Context is maintained across runs + print("=== First Run on workflow_a ===") + await run_workflow(workflow_a, "We are launching a new budget-friendly electric bike for urban commuters.") + print("\n=== Second Run on workflow_a ===") + await run_workflow(workflow_a, "Refine your response to focus on the California market.") + + # Build workflow_b + # This will create new instances of all participants and the aggregator + # The agents will also get new threads + workflow_b = concurrent_builder.build() + # Run workflow_b + # Context is not maintained across instances + # Should not expect mentions of electric bikes in the results + print("\n=== First Run on workflow_b ===") + await run_workflow(workflow_b, "Refine your response to focus on the California market.") + + """ + Sample Output: + + === First Run on workflow_a === + The budget-friendly electric bike market is poised for significant growth, driven by urbanization, ... + + === Second Run on workflow_a === + Launching a budget-friendly electric bike in California presents significant opportunities, driven ... + + === First Run on workflow_b === + To successfully penetrate the California market, consider these tailored strategies focused on ... + """ + + +if __name__ == "__main__": + asyncio.run(main())