Skip to content

Commit 236f562

Browse files
GWealecopybara-github
authored andcommitted
fix: Load agent/app before creating session
This change loads the agent or app from the specified directory before creating the session. This allows using the correct application name (from the `App` object if applicable) when initializing the session, rather than always defaulting to the folder name. The variable `root_agent` is also renamed to `agent_or_app` to better reflect that it can be either an Agent or an App Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 833839070
1 parent 4dd28a3 commit 236f562

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

src/google/adk/cli/cli.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,22 @@ async def run_cli(
155155
credential_service = InMemoryCredentialService()
156156

157157
user_id = 'test_user'
158-
session = await session_service.create_session(
159-
app_name=agent_folder_name, user_id=user_id
160-
)
161-
root_agent = AgentLoader(agents_dir=agent_parent_dir).load_agent(
158+
agent_or_app = AgentLoader(agents_dir=agent_parent_dir).load_agent(
162159
agent_folder_name
163160
)
161+
session_app_name = (
162+
agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name
163+
)
164+
session = await session_service.create_session(
165+
app_name=session_app_name, user_id=user_id
166+
)
164167
if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'):
165168
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
166169
if input_file:
167170
session = await run_input_file(
168-
app_name=agent_folder_name,
171+
app_name=session_app_name,
169172
user_id=user_id,
170-
agent_or_app=root_agent,
173+
agent_or_app=agent_or_app,
171174
artifact_service=artifact_service,
172175
session_service=session_service,
173176
credential_service=credential_service,
@@ -186,16 +189,16 @@ async def run_cli(
186189
click.echo(f'[{event.author}]: {content.parts[0].text}')
187190

188191
await run_interactively(
189-
root_agent,
192+
agent_or_app,
190193
artifact_service,
191194
session,
192195
session_service,
193196
credential_service,
194197
)
195198
else:
196-
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
199+
click.echo(f'Running agent {agent_or_app.name}, type exit to exit.')
197200
await run_interactively(
198-
root_agent,
201+
agent_or_app,
199202
artifact_service,
200203
session,
201204
session_service,

tests/unittests/cli/utils/test_cli.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import click
2929
from google.adk.agents.base_agent import BaseAgent
30+
from google.adk.apps.app import App
3031
import google.adk.cli.cli as cli
3132
import pytest
3233

@@ -108,6 +109,28 @@ def __init__(self, name):
108109
return parent_dir, "fake_agent"
109110

110111

112+
@pytest.fixture()
113+
def fake_app_agent(tmp_path: Path):
114+
"""Create an agent package that exposes an App."""
115+
116+
parent_dir = tmp_path / "agents"
117+
parent_dir.mkdir()
118+
agent_dir = parent_dir / "fake_app_agent"
119+
agent_dir.mkdir()
120+
(agent_dir / "__init__.py").write_text(dedent("""
121+
from google.adk.agents.base_agent import BaseAgent
122+
from google.adk.apps.app import App
123+
class FakeAgent(BaseAgent):
124+
def __init__(self, name):
125+
super().__init__(name=name)
126+
127+
root_agent = FakeAgent(name="fake_root")
128+
app = App(name="custom_cli_app", root_agent=root_agent)
129+
"""))
130+
131+
return parent_dir, "fake_app_agent", "custom_cli_app"
132+
133+
111134
# _run_input_file
112135
@pytest.mark.asyncio
113136
async def test_run_input_file_outputs(
@@ -166,6 +189,40 @@ async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None:
166189
)
167190

168191

192+
@pytest.mark.asyncio
193+
async def test_run_cli_app_uses_app_name_for_sessions(
194+
fake_app_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
195+
) -> None:
196+
"""run_cli should honor the App-provided name when creating sessions."""
197+
parent_dir, folder_name, app_name = fake_app_agent
198+
created_app_names: List[str] = []
199+
200+
original_session_cls = cli.InMemorySessionService
201+
202+
class _SpySessionService(original_session_cls):
203+
204+
async def create_session(self, *, app_name: str, **kwargs: Any) -> Any:
205+
created_app_names.append(app_name)
206+
return await super().create_session(app_name=app_name, **kwargs)
207+
208+
monkeypatch.setattr(cli, "InMemorySessionService", _SpySessionService)
209+
210+
input_json = {"state": {}, "queries": ["ping"]}
211+
input_path = tmp_path / "input_app.json"
212+
input_path.write_text(json.dumps(input_json))
213+
214+
await cli.run_cli(
215+
agent_parent_dir=str(parent_dir),
216+
agent_folder_name=folder_name,
217+
input_file=str(input_path),
218+
saved_session_file=None,
219+
save_session=False,
220+
)
221+
222+
assert created_app_names
223+
assert all(name == app_name for name in created_app_names)
224+
225+
169226
# _run_cli (interactive + save session branch)
170227
@pytest.mark.asyncio
171228
async def test_run_cli_save_session(

0 commit comments

Comments
 (0)