Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/packages/core/agent_framework/_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _merge_chat_options(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] | None = None,
top_p: float | None = None,
user: str | None = None,
Expand Down Expand Up @@ -496,7 +496,7 @@ async def get_response(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
Expand Down Expand Up @@ -595,7 +595,7 @@ async def get_streaming_response(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
Expand Down
6 changes: 6 additions & 0 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,12 @@ async def function_invocation_wrapper(
prepped_messages = prepare_messages(messages)
response: "ChatResponse | None" = None
fcc_messages: "list[ChatMessage]" = []

# If tools are provided but tool_choice is not set, default to "auto" for function invocation
tools = _extract_tools(kwargs)
if tools and kwargs.get("tool_choice") is None:
kwargs["tool_choice"] = "auto"

for attempt_idx in range(config.max_iterations if config.enabled else 0):
fcc_todo = _collect_approval_responses(prepped_messages)
if fcc_todo:
Expand Down
158 changes: 143 additions & 15 deletions python/packages/core/agent_framework/_workflows/_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent:
# so we need to recombine them here to pass the complete tools list to the constructor.
# This makes sure MCP tools are preserved when cloning agents for handoff workflows.
all_tools = list(options.tools) if options.tools else []
if agent._local_mcp_tools:
all_tools.extend(agent._local_mcp_tools)
if agent._local_mcp_tools: # type: ignore
all_tools.extend(agent._local_mcp_tools) # type: ignore

return ChatAgent(
chat_client=agent.chat_client,
Expand Down Expand Up @@ -133,6 +133,14 @@ class _ConversationWithUserInput:
full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc]


@dataclass
class _ConversationForUserInput:
"""Internal message from coordinator to gateway specifying which agent will receive the response."""

conversation: list[ChatMessage]
next_agent_id: str


class _AutoHandoffMiddleware(FunctionMiddleware):
"""Intercept handoff tool invocations and short-circuit execution with synthetic results."""

Expand Down Expand Up @@ -275,6 +283,7 @@ def __init__(
termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]],
id: str,
handoff_tool_targets: Mapping[str, str] | None = None,
return_to_previous: bool = False,
) -> None:
"""Create a coordinator that manages routing between specialists and the user."""
super().__init__(id)
Expand All @@ -284,6 +293,8 @@ def __init__(
self._input_gateway_id = input_gateway_id
self._termination_condition = termination_condition
self._handoff_tool_targets = {k.lower(): v for k, v in (handoff_tool_targets or {}).items()}
self._return_to_previous = return_to_previous
self._current_agent_id: str | None = None # Track the current agent handling conversation

def _get_author_name(self) -> str:
"""Get the coordinator name for orchestrator-generated messages."""
Expand All @@ -293,7 +304,7 @@ def _get_author_name(self) -> str:
async def handle_agent_response(
self,
response: AgentExecutorResponse,
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage]],
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
) -> None:
"""Process an agent's response and determine whether to route, request input, or terminate."""
# Hydrate coordinator state (and detect new run) using checkpointable executor state
Expand Down Expand Up @@ -329,6 +340,9 @@ async def handle_agent_response(
# Check for handoff from ANY agent (starting agent or specialist)
target = self._resolve_specialist(response.agent_run_response, conversation)
if target is not None:
# Update current agent when handoff occurs
self._current_agent_id = target
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
await self._persist_state(ctx)
# Clean tool-related content before sending to next agent
cleaned = clean_conversation_for_handoff(conversation)
Expand All @@ -340,10 +354,15 @@ async def handle_agent_response(
if not is_starting_agent and source not in self._specialist_ids:
raise RuntimeError(f"HandoffCoordinator received response from unknown executor '{source}'.")

# Update current agent when they respond without handoff
self._current_agent_id = source
logger.info(
f"Agent '{source}' responded without handoff. "
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
)
await self._persist_state(ctx)

if await self._check_termination():
logger.info("Handoff workflow termination condition met. Ending conversation.")
# Clean the output conversation for display
cleaned_output = clean_conversation_for_handoff(conversation)
await ctx.yield_output(cleaned_output)
Expand All @@ -352,7 +371,13 @@ async def handle_agent_response(
# Clean conversation before sending to gateway for user input request
# This removes tool messages that shouldn't be shown to users
cleaned_for_display = clean_conversation_for_handoff(conversation)
await ctx.send_message(cleaned_for_display, target_id=self._input_gateway_id)

# The awaiting_agent_id is the agent that just responded and is awaiting user input
# This is the source of the current response
next_agent_id = source

message_to_gateway = _ConversationForUserInput(conversation=cleaned_for_display, next_agent_id=next_agent_id)
await ctx.send_message(message_to_gateway, target_id=self._input_gateway_id) # type: ignore[arg-type]

@handler
async def handle_user_input(
Expand All @@ -367,14 +392,26 @@ async def handle_user_input(

# Check termination before sending to agent
if await self._check_termination():
logger.info("Handoff workflow termination condition met. Ending conversation.")
await ctx.yield_output(list(self._conversation))
return

# Clean before sending to starting agent
# Determine routing target based on return-to-previous setting
target_agent_id = self._starting_agent_id
if self._return_to_previous and self._current_agent_id:
# Route back to the current agent that's handling the conversation
target_agent_id = self._current_agent_id
logger.info(
f"Return-to-previous enabled: routing user input to current agent '{target_agent_id}' "
f"(bypassing coordinator '{self._starting_agent_id}')"
)
else:
logger.info(f"Routing user input to coordinator '{target_agent_id}'")
# Note: Stack is only used for specialist-to-specialist handoffs, not user input routing

# Clean before sending to target agent
cleaned = clean_conversation_for_handoff(self._conversation)
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
await ctx.send_message(request, target_id=self._starting_agent_id)
await ctx.send_message(request, target_id=target_agent_id)

def _resolve_specialist(self, agent_response: AgentRunResponse, conversation: list[ChatMessage]) -> str | None:
"""Resolve the specialist executor id requested by the agent response, if any."""
Expand Down Expand Up @@ -444,22 +481,27 @@ async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.

Handoff has no additional metadata beyond base conversation state.
Includes the current agent for return-to-previous routing.

Returns:
Empty dict (no pattern-specific state)
Dict containing current agent if return-to-previous is enabled
"""
if self._return_to_previous:
return {
"current_agent_id": self._current_agent_id,
}
return {}

def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.

Handoff has no additional metadata beyond base conversation state.
Restores the current agent for return-to-previous routing.

Args:
metadata: Pattern-specific state dict (ignored)
metadata: Pattern-specific state dict
"""
pass
if self._return_to_previous and "current_agent_id" in metadata:
self._current_agent_id = metadata["current_agent_id"]

def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
"""Rehydrate the coordinator's conversation history from checkpointed state.
Expand Down Expand Up @@ -507,8 +549,21 @@ def __init__(
self._prompt = prompt or "Provide your next input for the conversation."

@handler
async def request_input(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
async def request_input(self, message: _ConversationForUserInput, ctx: WorkflowContext) -> None:
"""Emit a `HandoffUserInputRequest` capturing the conversation snapshot."""
if not message.conversation:
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
request = HandoffUserInputRequest(
conversation=list(message.conversation),
awaiting_agent_id=message.next_agent_id,
prompt=self._prompt,
source_executor_id=self.id,
)
await ctx.request_info(request, object)

@handler
async def request_input_legacy(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
"""Legacy handler for backward compatibility - emit user input request with starting agent."""
if not conversation:
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
request = HandoffUserInputRequest(
Expand Down Expand Up @@ -558,7 +613,7 @@ def _as_user_messages(payload: Any) -> list[ChatMessage]:


def _default_termination_condition(conversation: list[ChatMessage]) -> bool:
"""Default termination: stop after 10 user messages to prevent infinite loops."""
"""Default termination: stop after 10 user messages."""
user_message_count = sum(1 for msg in conversation if msg.role == Role.USER)
return user_message_count >= 10

Expand Down Expand Up @@ -743,6 +798,7 @@ def __init__(
)
self._auto_register_handoff_tools: bool = True
self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids]
self._return_to_previous: bool = False

if participants:
self.participants(participants)
Expand Down Expand Up @@ -1198,6 +1254,77 @@ async def check_termination(conv: list[ChatMessage]) -> bool:
self._termination_condition = condition
return self

def enable_return_to_previous(self, enabled: bool = True) -> "HandoffBuilder":
"""Enable direct return to the current agent after user input, bypassing the coordinator.

When enabled, after a specialist responds without requesting another handoff, user input
routes directly back to that same specialist instead of always routing back to the
coordinator agent for re-evaluation.

This is useful when a specialist needs multiple turns with the user to gather information
or resolve an issue, avoiding unnecessary coordinator involvement while maintaining context.

Flow Comparison:

**Default (disabled):**
User -> Coordinator -> Specialist -> User -> Coordinator -> Specialist -> ...

**With return_to_previous (enabled):**
User -> Coordinator -> Specialist -> User -> Specialist -> ...

Args:
enabled: Whether to enable return-to-previous routing. Default is True.

Returns:
Self for method chaining.

Example:

.. code-block:: python

workflow = (
HandoffBuilder(participants=[triage, technical_support, billing])
.set_coordinator("triage")
.add_handoff(triage, [technical_support, billing])
.enable_return_to_previous() # Enable direct return routing
.build()
)

# Flow: User asks question
# -> Triage routes to Technical Support
# -> Technical Support asks clarifying question
# -> User provides more info
# -> Routes back to Technical Support (not Triage)
# -> Technical Support continues helping

Multi-tier handoff example:

.. code-block:: python

workflow = (
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
.set_coordinator("triage")
.add_handoff(triage, [specialist_a, specialist_b])
.add_handoff(specialist_a, specialist_b)
.enable_return_to_previous()
.build()
)

# Flow: User asks question
# -> Triage routes to Specialist A
# -> Specialist A hands off to Specialist B
# -> Specialist B asks clarifying question
# -> User provides more info
# -> Routes back to Specialist B (who is currently handling the conversation)

Note:
This feature routes to whichever agent most recently responded, whether that's
the coordinator or a specialist. The conversation continues with that agent until
they either hand off to another agent or the termination condition is met.
"""
self._return_to_previous = enabled
return self

def build(self) -> Workflow:
"""Construct the final Workflow instance from the configured builder.

Expand Down Expand Up @@ -1326,6 +1453,7 @@ def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor:
termination_condition=self._termination_condition,
id="handoff-coordinator",
handoff_tool_targets=handoff_tool_targets,
return_to_previous=self._return_to_previous,
)

wiring = _GroupChatConfig(
Expand Down
Loading