Skip to content

Commit 592cf4c

Browse files
li-boxuanKevin Chen
authored and
Kevin Chen
committed
Trajectory replay: Fix a few corner cases (All-Hands-AI#6380)
1 parent 959492f commit 592cf4c

File tree

6 files changed

+863
-4
lines changed

6 files changed

+863
-4
lines changed

openhands/controller/agent_controller.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ async def _step(self) -> None:
662662
action = self.agent.step(self.state)
663663
if action is None:
664664
raise LLMNoActionError('No action was returned')
665+
action._source = EventSource.AGENT # type: ignore [attr-defined]
665666
except (
666667
LLMMalformedActionError,
667668
LLMNoActionError,
@@ -720,7 +721,7 @@ async def _step(self) -> None:
720721
== ActionConfirmationStatus.AWAITING_CONFIRMATION
721722
):
722723
await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
723-
self.event_stream.add_event(action, EventSource.AGENT)
724+
self.event_stream.add_event(action, action._source) # type: ignore [attr-defined]
724725

725726
await self.update_state_after_step()
726727

openhands/controller/replay.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from openhands.core.logger import openhands_logger as logger
22
from openhands.events.action.action import Action
3+
from openhands.events.action.message import MessageAction
34
from openhands.events.event import Event, EventSource
5+
from openhands.events.observation.empty import NullObservation
46

57

68
class ReplayManager:
@@ -15,9 +17,31 @@ class ReplayManager:
1517
initial state of the trajectory.
1618
"""
1719

18-
def __init__(self, replay_events: list[Event] | None):
20+
def __init__(self, events: list[Event] | None):
21+
replay_events = []
22+
for event in events or []:
23+
if event.source == EventSource.ENVIRONMENT:
24+
# ignore ENVIRONMENT events as they are not issued by
25+
# the user or agent, and should not be replayed
26+
continue
27+
if isinstance(event, NullObservation):
28+
# ignore NullObservation
29+
continue
30+
replay_events.append(event)
31+
1932
if replay_events:
20-
logger.info(f'Replay logs loaded, events length = {len(replay_events)}')
33+
logger.info(f'Replay events loaded, events length = {len(replay_events)}')
34+
for index in range(len(replay_events) - 1):
35+
event = replay_events[index]
36+
if isinstance(event, MessageAction) and event.wait_for_response:
37+
# For any message waiting for response that is not the last
38+
# event, we override wait_for_response to False, as a response
39+
# would have been included in the next event, and we don't
40+
# want the user to interfere with the replay process
41+
logger.info(
42+
'Replay events contains wait_for_response message action, ignoring wait_for_response'
43+
)
44+
event.wait_for_response = False
2145
self.replay_events = replay_events
2246
self.replay_mode = bool(replay_events)
2347
self.replay_index = 0
@@ -27,7 +51,6 @@ def _replayable(self) -> bool:
2751
self.replay_events is not None
2852
and self.replay_index < len(self.replay_events)
2953
and isinstance(self.replay_events[self.replay_index], Action)
30-
and self.replay_events[self.replay_index].source != EventSource.USER
3154
)
3255

3356
def should_replay(self) -> bool:

openhands/core/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
231231
events = []
232232
for item in data:
233233
event = event_from_dict(item)
234+
if event.source == EventSource.ENVIRONMENT:
235+
# ignore ENVIRONMENT events as they are not issued by
236+
# the user or agent, and should not be replayed
237+
continue
234238
# cannot add an event with _id to event stream
235239
event._id = None # type: ignore[attr-defined]
236240
events.append(event)

tests/runtime/test_replay.py

+72
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from openhands.core.main import run_controller
1111
from openhands.core.schema.agent import AgentState
1212
from openhands.events.action.empty import NullAction
13+
from openhands.events.action.message import MessageAction
14+
from openhands.events.event import EventSource
1315
from openhands.events.observation.commands import CmdOutputObservation
1416

1517

@@ -46,6 +48,36 @@ def test_simple_replay(temp_dir, runtime_cls, run_as_openhands):
4648
_close_test_runtime(runtime)
4749

4850

51+
def test_simple_gui_replay(temp_dir, runtime_cls, run_as_openhands):
52+
"""
53+
A simple replay test that involves simple terminal operations and edits
54+
(writing a Vue.js App), using the default agent
55+
56+
Note:
57+
1. This trajectory is exported from GUI mode, meaning it has extra
58+
environmental actions that don't appear in headless mode's trajectories
59+
2. In GUI mode, agents typically don't finish; rather, they wait for the next
60+
task from the user, so this exported trajectory ends with awaiting_user_input
61+
"""
62+
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
63+
64+
config = _get_config('basic_gui_mode')
65+
66+
state: State | None = asyncio.run(
67+
run_controller(
68+
config=config,
69+
initial_user_action=NullAction(),
70+
runtime=runtime,
71+
# exit on message, otherwise this would be stuck on waiting for user input
72+
exit_on_message=True,
73+
)
74+
)
75+
76+
assert state.agent_state == AgentState.FINISHED
77+
78+
_close_test_runtime(runtime)
79+
80+
4981
def test_replay_wrong_initial_state(temp_dir, runtime_cls, run_as_openhands):
5082
"""
5183
Replay requires a consistent initial state to start with, otherwise it might
@@ -78,3 +110,43 @@ def test_replay_wrong_initial_state(temp_dir, runtime_cls, run_as_openhands):
78110
assert has_error_in_action
79111

80112
_close_test_runtime(runtime)
113+
114+
115+
def test_replay_basic_interactions(temp_dir, runtime_cls, run_as_openhands):
116+
"""
117+
Replay a trajectory that involves interactions, i.e. with user messages
118+
in the middle. This tests two things:
119+
1) The controller should be able to replay all actions without human
120+
interference (no asking for user input).
121+
2) The user messages in the trajectory should appear in the history.
122+
"""
123+
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
124+
125+
config = _get_config('basic_interactions')
126+
127+
state: State | None = asyncio.run(
128+
run_controller(
129+
config=config,
130+
initial_user_action=NullAction(),
131+
runtime=runtime,
132+
)
133+
)
134+
135+
assert state.agent_state == AgentState.FINISHED
136+
137+
# all user messages appear in the history, so that after a replay (assuming
138+
# the trajectory doesn't end with `finish` action), LLM knows about all the
139+
# context and can continue
140+
user_messages = [
141+
"what's 1+1?",
142+
"No, I mean by Goldbach's conjecture!",
143+
'Finish please',
144+
]
145+
i = 0
146+
for event in state.history:
147+
if isinstance(event, MessageAction) and event._source == EventSource.USER:
148+
assert event.message == user_messages[i]
149+
i += 1
150+
assert i == len(user_messages)
151+
152+
_close_test_runtime(runtime)

0 commit comments

Comments
 (0)