Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import contextlib
import copy
import functools
import inspect
import logging
Expand Down Expand Up @@ -263,8 +264,9 @@ async def execute(
)

# Invoke the handler with the message and context
# Use deepcopy to capture original input state before handler can mutate it
with _framework_event_origin():
invoke_event = ExecutorInvokedEvent(self.id, message)
invoke_event = ExecutorInvokedEvent(self.id, copy.deepcopy(message))
await context.add_event(invoke_event)
try:
await handler(message, context)
Expand All @@ -275,9 +277,11 @@ async def execute(
await context.add_event(failure_event)
raise
with _framework_event_origin():
# Include sent messages as the completion data
# Include sent messages and yielded outputs as the completion data
sent_messages = context.get_sent_messages()
completed_event = ExecutorCompletedEvent(self.id, sent_messages if sent_messages else None)
yielded_outputs = context.get_yielded_outputs()
completion_data = sent_messages + yielded_outputs
completed_event = ExecutorCompletedEvent(self.id, completion_data if completion_data else None)
await context.add_event(completed_event)

def _create_context_for_handler(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import copy
import inspect
import logging
import uuid
Expand Down Expand Up @@ -290,6 +291,9 @@ def __init__(
# Track messages sent via send_message() for ExecutorCompletedEvent
self._sent_messages: list[Any] = []

# Track outputs yielded via yield_output() for ExecutorCompletedEvent
self._yielded_outputs: list[Any] = []

# Store trace contexts and source span IDs for linking (supporting multiple sources)
self._trace_contexts = trace_contexts or []
self._source_span_ids = source_span_ids or []
Expand Down Expand Up @@ -336,6 +340,9 @@ async def yield_output(self, output: T_W_Out) -> None:
output: The output to yield. This must conform to the workflow output type(s)
declared on this context.
"""
# Track yielded output for ExecutorCompletedEvent (deepcopy to capture state at yield time)
self._yielded_outputs.append(copy.deepcopy(output))

with _framework_event_origin():
event = WorkflowOutputEvent(data=output, source_executor_id=self._executor_id)
await self._runner_context.add_event(event)
Expand Down Expand Up @@ -424,6 +431,14 @@ def get_sent_messages(self) -> list[Any]:
"""
return self._sent_messages.copy()

def get_yielded_outputs(self) -> list[Any]:
"""Get all outputs yielded via yield_output() during this handler execution.

Returns:
A list of outputs that were yielded as workflow outputs.
"""
return self._yielded_outputs.copy()

@deprecated(
"Override `on_checkpoint_save()` methods instead. "
"For cross-executor state sharing, use set_shared_state() instead. "
Expand Down
43 changes: 39 additions & 4 deletions python/packages/core/tests/workflow/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import pytest

from agent_framework import (
ChatMessage,
Executor,
ExecutorCompletedEvent,
ExecutorInvokedEvent,
Message,
WorkflowBuilder,
WorkflowContext,
executor,
handler,
)

Expand Down Expand Up @@ -182,8 +184,8 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None:
assert collector_completed.data is None


async def test_executor_completed_event_none_when_no_messages_sent():
"""Test that ExecutorCompletedEvent.data is None when no messages are sent."""
async def test_executor_completed_event_includes_yielded_outputs():
"""Test that ExecutorCompletedEvent.data includes yielded outputs."""
from typing_extensions import Never

from agent_framework import WorkflowOutputEvent
Expand All @@ -201,9 +203,10 @@ async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None:

assert len(completed_events) == 1
assert completed_events[0].executor_id == "yielder"
assert completed_events[0].data is None
# Yielded outputs are now included in ExecutorCompletedEvent.data
assert completed_events[0].data == ["TEST"]

# Verify the output was still yielded correctly
# Verify the output was also yielded as WorkflowOutputEvent
output_events = [e for e in events if isinstance(e, WorkflowOutputEvent)]
assert len(output_events) == 1
assert output_events[0].data == "TEST"
Expand Down Expand Up @@ -261,3 +264,35 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None:
collector_invoked = next(e for e in invoked_events if e.executor_id == "collector")
assert isinstance(collector_invoked.data, Response)
assert collector_invoked.data.results == ["HELLO", "HELLO", "HELLO"]


async def test_executor_invoked_event_data_not_mutated_by_handler():
"""Test that ExecutorInvokedEvent.data captures original input, not mutated input."""

@executor(id="Mutator")
async def mutator(messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
# The handler mutates the input list by appending new messages
original_len = len(messages)
messages.append(ChatMessage(role="assistant", text="Added by executor"))
await ctx.send_message(messages)
# Verify mutation happened
assert len(messages) == original_len + 1

workflow = WorkflowBuilder().set_start_executor(mutator).build()

# Run with a single user message
input_messages = [ChatMessage(role="user", text="hello")]
events = await workflow.run(input_messages)

# Find the invoked event for the Mutator executor
invoked_events = [e for e in events if isinstance(e, ExecutorInvokedEvent)]
assert len(invoked_events) == 1
mutator_invoked = invoked_events[0]

# The event data should contain ONLY the original input (1 user message)
assert mutator_invoked.executor_id == "Mutator"
assert len(mutator_invoked.data) == 1, (
f"Expected 1 message (original input), got {len(mutator_invoked.data)}: "
f"{[m.text for m in mutator_invoked.data]}"
)
assert mutator_invoked.data[0].text == "hello"
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ async def main() -> None:
Input: str: 'HELLO WORLD'
[WORKFLOW OUTPUT] str: 'DLROW OLLEH'
[COMPLETED] reverse_text
Output: list: [str: 'DLROW OLLEH']
"""


Expand Down
Loading