Skip to content

Commit bb1ea74

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Delegate the agent state reset logic to LoopAgent
This is so we don't need to worry about side effect of Loop in all agent type. Custom agent should do the same if there exists loop inside. PiperOrigin-RevId: 818766305
1 parent 214986e commit bb1ea74

File tree

5 files changed

+55
-12
lines changed

5 files changed

+55
-12
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,25 @@ def set_agent_state(
249249
self.end_of_agents.pop(agent_name, None)
250250
self.agent_states.pop(agent_name, None)
251251

252+
def reset_sub_agent_states(
253+
self,
254+
agent_name: str,
255+
) -> None:
256+
"""Resets the state of all sub-agents of the given agent in this invocation.
257+
258+
Args:
259+
agent_name: The name of the agent whose sub-agent states need to be reset.
260+
"""
261+
agent = self.agent.find_agent(agent_name)
262+
if not agent:
263+
return
264+
265+
for sub_agent in agent.sub_agents:
266+
# Reset the sub-agent's state in the context to ensure that each
267+
# sub-agent starts fresh.
268+
self.set_agent_state(sub_agent.name)
269+
self.reset_sub_agent_states(sub_agent.name)
270+
252271
def populate_invocation_agent_states(self) -> None:
253272
"""Populates agent states for the current invocation if it is resumable.
254273

src/google/adk/agents/loop_agent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ async def _run_async_impl(
9494
ctx.set_agent_state(self.name, agent_state=agent_state)
9595
yield self._create_agent_state_event(ctx)
9696

97-
# Reset the sub-agent's state in the context to ensure that each
98-
# sub-agent starts fresh.
99-
if not is_resuming_at_current_agent:
100-
ctx.set_agent_state(sub_agent.name)
10197
is_resuming_at_current_agent = False
10298

10399
async with Aclosing(sub_agent.run_async(ctx)) as agen:
@@ -114,6 +110,8 @@ async def _run_async_impl(
114110
# Restart from the beginning of the loop.
115111
start_index = 0
116112
times_looped += 1
113+
# Reset the state of all sub-agents in the loop.
114+
ctx.reset_sub_agent_states(self.name)
117115

118116
# If the invocation is paused, we should not yield the end of agent event.
119117
if pause_invocation:

src/google/adk/agents/parallel_agent.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,6 @@ async def _run_async_impl(
187187
agent_runs = []
188188
# Prepare and collect async generators for each sub-agent.
189189
for sub_agent in self.sub_agents:
190-
if agent_state is None:
191-
# Reset sub-agent state to make sure each sub-agent starts fresh.
192-
ctx.set_agent_state(sub_agent.name)
193-
194190
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
195191

196192
# Only include sub-agents that haven't finished in a previous run.

src/google/adk/agents/sequential_agent.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,6 @@ async def _run_async_impl(
7373
ctx.set_agent_state(self.name, agent_state=agent_state)
7474
yield self._create_agent_state_event(ctx)
7575

76-
# Reset the sub-agent's state in the context to ensure that each
77-
# sub-agent starts fresh.
78-
ctx.set_agent_state(sub_agent.name)
79-
8076
async with Aclosing(sub_agent.run_async(ctx)) as agen:
8177
async for event in agen:
8278
yield event

tests/unittests/agents/test_invocation_context.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,40 @@ def test_reset_agent_state(self):
390390
assert 'agent1' not in invocation_context.agent_states
391391
assert 'agent1' not in invocation_context.end_of_agents
392392

393+
def test_reset_sub_agent_states(self):
394+
"""Tests that reset_sub_agent_states resets sub-agent states."""
395+
sub_sub_agent_1 = BaseAgent(name='sub_sub_agent_1')
396+
sub_agent_1 = BaseAgent(name='sub_agent_1', sub_agents=[sub_sub_agent_1])
397+
sub_agent_2 = BaseAgent(name='sub_agent_2')
398+
root_agent = BaseAgent(
399+
name='root_agent', sub_agents=[sub_agent_1, sub_agent_2]
400+
)
401+
402+
invocation_context = self._create_test_invocation_context(
403+
ResumabilityConfig(is_resumable=True)
404+
)
405+
invocation_context.agent = root_agent
406+
invocation_context.set_agent_state(
407+
'sub_agent_1', agent_state=BaseAgentState()
408+
)
409+
invocation_context.set_agent_state('sub_agent_2', end_of_agent=True)
410+
invocation_context.set_agent_state(
411+
'sub_sub_agent_1', agent_state=BaseAgentState()
412+
)
413+
414+
assert 'sub_agent_1' in invocation_context.agent_states
415+
assert 'sub_agent_2' in invocation_context.end_of_agents
416+
assert 'sub_sub_agent_1' in invocation_context.agent_states
417+
418+
invocation_context.reset_sub_agent_states('root_agent')
419+
420+
assert 'sub_agent_1' not in invocation_context.agent_states
421+
assert 'sub_agent_1' not in invocation_context.end_of_agents
422+
assert 'sub_agent_2' not in invocation_context.agent_states
423+
assert 'sub_agent_2' not in invocation_context.end_of_agents
424+
assert 'sub_sub_agent_1' not in invocation_context.agent_states
425+
assert 'sub_sub_agent_1' not in invocation_context.end_of_agents
426+
393427

394428
class TestFindMatchingFunctionCall:
395429
"""Test suite for find_matching_function_call."""

0 commit comments

Comments
 (0)