Skip to content

Commit 8eb1bdb

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Add demo for rewind
Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 833871446
1 parent 236f562 commit 8eb1bdb

File tree

1 file changed

+166
-0
lines changed
  • contributing/samples/rewind_session

1 file changed

+166
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#!/usr/bin/env python3
2+
"""Simple test script for Rewind Session agent."""
3+
4+
# Copyright 2025 Google LLC
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import asyncio
19+
import logging
20+
21+
import agent
22+
from google.adk.agents.run_config import RunConfig
23+
from google.adk.cli.utils import logs
24+
from google.adk.events.event import Event
25+
from google.adk.runners import InMemoryRunner
26+
from google.genai import types
27+
28+
APP_NAME = "rewind_test_app"
29+
USER_ID = "test_user"
30+
31+
logs.setup_adk_logger(level=logging.ERROR)
32+
logging.getLogger("google_genai.types").setLevel(logging.ERROR)
33+
34+
35+
# ANSI color codes for terminal output
36+
COLOR_RED = "\x1b[31m"
37+
COLOR_BLUE = "\x1b[34m"
38+
COLOR_YELLOW = "\x1b[33m"
39+
COLOR_BOLD = "\x1b[1m"
40+
RESET = "\x1b[0m"
41+
42+
43+
def highlight(text: str) -> str:
44+
"""Adds color highlights to tool responses and agent text."""
45+
text = str(text)
46+
return (
47+
text.replace("'red'", f"'{COLOR_RED}red{RESET}'")
48+
.replace('"red"', f'"{COLOR_RED}red{RESET}"')
49+
.replace("'blue'", f"'{COLOR_BLUE}blue{RESET}'")
50+
.replace('"blue"', f'"{COLOR_BLUE}blue{RESET}"')
51+
.replace("'version1'", f"'{COLOR_BOLD}{COLOR_YELLOW}version1{RESET}'")
52+
.replace("'version2'", f"'{COLOR_BOLD}{COLOR_YELLOW}version2{RESET}'")
53+
)
54+
55+
56+
async def call_agent_async(
57+
runner: InMemoryRunner, user_id: str, session_id: str, prompt: str
58+
) -> list[Event]:
59+
"""Helper function to call the agent and return events."""
60+
print(f"\n👤 User: {prompt}")
61+
content = types.Content(
62+
role="user", parts=[types.Part.from_text(text=prompt)]
63+
)
64+
events = []
65+
try:
66+
async for event in runner.run_async(
67+
user_id=user_id,
68+
session_id=session_id,
69+
new_message=content,
70+
run_config=RunConfig(),
71+
):
72+
events.append(event)
73+
if event.content and event.author and event.author != "user":
74+
for part in event.content.parts:
75+
if part.text:
76+
print(f" 🤖 Agent: {highlight(part.text)}")
77+
elif part.function_call:
78+
print(f" 🛠️ Tool Call: {part.function_call.name}")
79+
elif part.function_response:
80+
print(
81+
" 📦 Tool Response:"
82+
f" {highlight(part.function_response.response)}"
83+
)
84+
except Exception as e:
85+
print(f"❌ Error during agent call: {e}")
86+
raise
87+
return events
88+
89+
90+
async def main():
91+
"""Demonstrates session rewind."""
92+
print("🚀 Testing Rewind Session Feature")
93+
print("=" * 50)
94+
95+
runner = InMemoryRunner(
96+
agent=agent.root_agent,
97+
app_name=APP_NAME,
98+
)
99+
100+
# Create a session
101+
session = await runner.session_service.create_session(
102+
app_name=APP_NAME, user_id=USER_ID
103+
)
104+
print(f"Created session: {session.id}")
105+
106+
# 1. Initial agent calls to set state and artifact
107+
print("\n\n===== INITIALIZING STATE AND ARTIFACT =====")
108+
await call_agent_async(
109+
runner, USER_ID, session.id, "set state `color` to red"
110+
)
111+
await call_agent_async(
112+
runner, USER_ID, session.id, "save artifact file1 with content version1"
113+
)
114+
115+
# 2. Check current state and artifact
116+
print("\n\n===== STATE BEFORE UPDATE =====")
117+
await call_agent_async(
118+
runner, USER_ID, session.id, "what is the value of state `color`?"
119+
)
120+
await call_agent_async(runner, USER_ID, session.id, "load artifact file1")
121+
122+
# 3. Update state and artifact - THIS IS THE POINT WE WILL REWIND BEFORE
123+
print("\n\n===== UPDATING STATE AND ARTIFACT =====")
124+
events_update_state = await call_agent_async(
125+
runner, USER_ID, session.id, "update state key color to blue"
126+
)
127+
rewind_invocation_id = events_update_state[0].invocation_id
128+
print(f"Will rewind before invocation: {rewind_invocation_id}")
129+
130+
await call_agent_async(
131+
runner, USER_ID, session.id, "save artifact file1 with content version2"
132+
)
133+
134+
# 4. Check state and artifact after update
135+
print("\n\n===== STATE AFTER UPDATE =====")
136+
await call_agent_async(
137+
runner, USER_ID, session.id, "what is the value of state key color?"
138+
)
139+
await call_agent_async(runner, USER_ID, session.id, "load artifact file1")
140+
141+
# 5. Perform rewind
142+
print(f"\n\n===== REWINDING SESSION to before {rewind_invocation_id} =====")
143+
await runner.rewind_async(
144+
user_id=USER_ID,
145+
session_id=session.id,
146+
rewind_before_invocation_id=rewind_invocation_id,
147+
)
148+
print("✅ Rewind complete.")
149+
150+
# 6. Check state and artifact after rewind
151+
print("\n\n===== STATE AFTER REWIND =====")
152+
await call_agent_async(
153+
runner, USER_ID, session.id, "what is the value of state `color`?"
154+
)
155+
await call_agent_async(runner, USER_ID, session.id, "load artifact file1")
156+
157+
print("\n" + "=" * 50)
158+
print("✨ Rewind testing complete!")
159+
print(
160+
"🔧 If rewind was successful, color should be 'red' and file1 content"
161+
" should contain 'version1' in the final check."
162+
)
163+
164+
165+
if __name__ == "__main__":
166+
asyncio.run(main())

0 commit comments

Comments
 (0)