diff --git a/openhands-sdk/openhands/sdk/context/view.py b/openhands-sdk/openhands/sdk/context/view.py deleted file mode 100644 index c80739e1ac..0000000000 --- a/openhands-sdk/openhands/sdk/context/view.py +++ /dev/null @@ -1,505 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from collections.abc import Sequence -from functools import cached_property -from logging import getLogger -from typing import overload - -from pydantic import BaseModel, computed_field - -from openhands.sdk.event import ( - Condensation, - CondensationRequest, - CondensationSummaryEvent, - LLMConvertibleEvent, -) -from openhands.sdk.event.base import Event, EventID -from openhands.sdk.event.llm_convertible import ( - ActionEvent, - ObservationBaseEvent, -) -from openhands.sdk.event.types import ToolCallID - - -logger = getLogger(__name__) - - -class ActionBatch(BaseModel): - """Represents a batch of ActionEvents grouped by llm_response_id. - - This is a utility class used to help detect and manage batches of ActionEvents - that share the same llm_response_id, which indicates they were generated together - by the LLM. This is important for ensuring atomicity when manipulating events - in a View, such as during condensation. - """ - - batches: dict[EventID, list[EventID]] - """dict mapping llm_response_id to list of ActionEvent IDs""" - - action_id_to_response_id: dict[EventID, EventID] - """dict mapping ActionEvent ID to llm_response_id""" - - action_id_to_tool_call_id: dict[EventID, ToolCallID] - """dict mapping ActionEvent ID to tool_call_id""" - - @staticmethod - def from_events( - events: Sequence[Event], - ) -> ActionBatch: - """Build a map of llm_response_id -> list of ActionEvent IDs.""" - batches: dict[EventID, list[EventID]] = defaultdict(list) - action_id_to_response_id: dict[EventID, EventID] = {} - action_id_to_tool_call_id: dict[EventID, ToolCallID] = {} - - for event in events: - if isinstance(event, ActionEvent): - llm_response_id = event.llm_response_id - batches[llm_response_id].append(event.id) - action_id_to_response_id[event.id] = llm_response_id - if event.tool_call_id is not None: - action_id_to_tool_call_id[event.id] = event.tool_call_id - - return ActionBatch( - batches=batches, - action_id_to_response_id=action_id_to_response_id, - action_id_to_tool_call_id=action_id_to_tool_call_id, - ) - - -class View(BaseModel): - """Linearly ordered view of events. - - Produced by a condenser to indicate the included events are ready to process as LLM - input. Also contains fields with information from the condensation process to aid - in deciding whether further condensation is needed. - """ - - events: list[LLMConvertibleEvent] - - unhandled_condensation_request: bool = False - """Whether there is an unhandled condensation request in the view.""" - - condensations: list[Condensation] = [] - """A list of condensations that were processed to produce the view.""" - - def __len__(self) -> int: - return len(self.events) - - @property - def most_recent_condensation(self) -> Condensation | None: - """Return the most recent condensation, or None if no condensations exist.""" - return self.condensations[-1] if self.condensations else None - - @property - def summary_event_index(self) -> int | None: - """Return the index of the summary event, or None if no summary exists.""" - recent_condensation = self.most_recent_condensation - if ( - recent_condensation is not None - and recent_condensation.summary is not None - and recent_condensation.summary_offset is not None - ): - return recent_condensation.summary_offset - return None - - @property - def summary_event(self) -> CondensationSummaryEvent | None: - """Return the summary event, or None if no summary exists.""" - if self.summary_event_index is not None: - event = self.events[self.summary_event_index] - if isinstance(event, CondensationSummaryEvent): - return event - return None - - @computed_field # type: ignore[prop-decorator] - @cached_property - def manipulation_indices(self) -> list[int]: - """Return cached manipulation indices for this view's events. - - These indices represent boundaries between atomic units where events can be - safely manipulated (inserted or forgotten). An atomic unit is either: - - A tool loop: a sequence of batches starting with thinking blocks and - continuing through all subsequent batches until a non-batch event - - A batch of ActionEvents with the same llm_response_id and their - corresponding ObservationBaseEvents (when not part of a tool loop) - - A single event that is neither an ActionEvent nor an ObservationBaseEvent - - Tool loops are identified by thinking blocks and must remain atomic to - preserve Claude API requirements that the final assistant message must - have thinking blocks when thinking is enabled. - - The returned indices can be used for: - - Inserting new events: any returned index is safe - - Forgetting events: select a range between two consecutive indices - - Consecutive indices define atomic units that must stay together: - - events[indices[i]:indices[i+1]] is an atomic unit - - Returns: - Sorted list of indices representing atomic unit boundaries. Always - includes 0 and len(events) as boundaries. - """ - if not self.events: - return [0] - - # Build mapping of llm_response_id -> list of event indices - batches: dict[EventID, list[int]] = {} - for idx, event in enumerate(self.events): - if isinstance(event, ActionEvent): - llm_response_id = event.llm_response_id - if llm_response_id not in batches: - batches[llm_response_id] = [] - batches[llm_response_id].append(idx) - - # Build mapping of tool_call_id -> observation indices - observation_indices: dict[ToolCallID, int] = {} - for idx, event in enumerate(self.events): - if ( - isinstance(event, ObservationBaseEvent) - and event.tool_call_id is not None - ): - observation_indices[event.tool_call_id] = idx - - # For each batch, find the range of indices that includes all actions - # and their corresponding observations, and track if batch has thinking blocks - batch_ranges: list[tuple[int, int, bool]] = [] - for llm_response_id, action_indices in batches.items(): - min_idx = min(action_indices) - max_idx = max(action_indices) - - # Check if this batch has thinking blocks (only first action has them) - first_action = self.events[min_idx] - has_thinking = ( - isinstance(first_action, ActionEvent) - and len(first_action.thinking_blocks) > 0 - ) - - # Extend the range to include all corresponding observations - for action_idx in action_indices: - action_event = self.events[action_idx] - if ( - isinstance(action_event, ActionEvent) - and action_event.tool_call_id is not None - ): - if action_event.tool_call_id in observation_indices: - obs_idx = observation_indices[action_event.tool_call_id] - max_idx = max(max_idx, obs_idx) - - batch_ranges.append((min_idx, max_idx, has_thinking)) - - # Sort batch ranges by start index for tool loop detection - batch_ranges.sort(key=lambda x: x[0]) - - # Identify tool loops: A tool loop starts with a batch that has thinking - # blocks and continues through all subsequent batches until we hit a - # non-ActionEvent/ObservationEvent (like a user MessageEvent). - tool_loop_ranges: list[tuple[int, int]] = [] - if batch_ranges: - i = 0 - while i < len(batch_ranges): - min_idx, max_idx, has_thinking = batch_ranges[i] - - # If this batch has thinking blocks, start a tool loop - if has_thinking: - loop_start = min_idx - loop_end = max_idx - - # Continue through ALL subsequent batches until we hit - # a non-batch event - j = i + 1 - while j < len(batch_ranges): - next_min, next_max, _ = batch_ranges[j] - - # Check if there's a non-batch event between current - # and next batch - has_non_batch_between = False - for k in range(loop_end + 1, next_min): - event = self.events[k] - if not isinstance( - event, (ActionEvent, ObservationBaseEvent) - ): - has_non_batch_between = True - break - - if has_non_batch_between: - # Tool loop ends before this non-batch event - break - - # Include this batch in the tool loop - loop_end = max(loop_end, next_max) - j += 1 - - tool_loop_ranges.append((loop_start, loop_end)) - i = j - else: - i += 1 - - # Merge batch ranges that are part of tool loops - # Create a mapping of batch index ranges to whether they're in a tool loop - merged_ranges: list[tuple[int, int]] = [] - - if tool_loop_ranges: - # Add tool loop ranges as atomic units - merged_ranges.extend(tool_loop_ranges) - - # Add non-tool-loop batch ranges - tool_loop_indices = set() - for loop_start, loop_end in tool_loop_ranges: - tool_loop_indices.update(range(loop_start, loop_end + 1)) - - for min_idx, max_idx, has_thinking in batch_ranges: - # Only add if not already covered by a tool loop - if min_idx not in tool_loop_indices: - merged_ranges.append((min_idx, max_idx)) - else: - # No tool loops, just use regular batch ranges - merged_ranges = [(min_idx, max_idx) for min_idx, max_idx, _ in batch_ranges] - - # Start with all possible indices (subtractive approach) - result_indices = set(range(len(self.events) + 1)) - - # Remove indices inside merged ranges (keep only boundaries) - for min_idx, max_idx in merged_ranges: - # Remove interior indices, keeping min_idx and max_idx+1 as boundaries - for idx in range(min_idx + 1, max_idx + 1): - result_indices.discard(idx) - - return sorted(result_indices) - - # To preserve list-like indexing, we ideally support slicing and position-based - # indexing. The only challenge with that is switching the return type based on the - # input type -- we can mark the different signatures for MyPy with `@overload` - # decorators. - - @overload - def __getitem__(self, key: slice) -> list[LLMConvertibleEvent]: ... - - @overload - def __getitem__(self, key: int) -> LLMConvertibleEvent: ... - - def __getitem__( - self, key: int | slice - ) -> LLMConvertibleEvent | list[LLMConvertibleEvent]: - if isinstance(key, slice): - start, stop, step = key.indices(len(self)) - return [self[i] for i in range(start, stop, step)] - elif isinstance(key, int): - return self.events[key] - else: - raise ValueError(f"Invalid key type: {type(key)}") - - @staticmethod - def _enforce_batch_atomicity( - events: Sequence[Event], - removed_event_ids: set[EventID], - ) -> set[EventID]: - """Ensure that if any ActionEvent in a batch is removed, all ActionEvents - in that batch are removed. - - This prevents partial batches from being sent to the LLM, which can cause - API errors when thinking blocks are separated from their tool calls. - - Args: - events: The original list of events - removed_event_ids: Set of event IDs that are being removed - - Returns: - Updated set of event IDs that should be removed (including all - ActionEvents in batches where any ActionEvent was removed) - """ - action_batch = ActionBatch.from_events(events) - - if not action_batch.batches: - return removed_event_ids - - updated_removed_ids = set(removed_event_ids) - - for llm_response_id, batch_event_ids in action_batch.batches.items(): - # Check if any ActionEvent in this batch is being removed - if any(event_id in removed_event_ids for event_id in batch_event_ids): - # If so, remove all ActionEvents in this batch - updated_removed_ids.update(batch_event_ids) - logger.debug( - f"Enforcing batch atomicity: removing entire batch " - f"with llm_response_id={llm_response_id} " - f"({len(batch_event_ids)} events)" - ) - - return updated_removed_ids - - @staticmethod - def filter_unmatched_tool_calls( - events: list[LLMConvertibleEvent], - ) -> list[LLMConvertibleEvent]: - """Filter out unmatched tool call events. - - Removes ActionEvents and ObservationEvents that have tool_call_ids - but don't have matching pairs. Also enforces batch atomicity - if any - ActionEvent in a batch is filtered out, all ActionEvents in that batch - are also filtered out. - """ - action_tool_call_ids = View._get_action_tool_call_ids(events) - observation_tool_call_ids = View._get_observation_tool_call_ids(events) - - # Build batch info for batch atomicity enforcement - action_batch = ActionBatch.from_events(events) - - # First pass: identify which events would NOT be kept based on matching - removed_event_ids: set[EventID] = set() - for event in events: - if not View._should_keep_event( - event, action_tool_call_ids, observation_tool_call_ids - ): - removed_event_ids.add(event.id) - - # Second pass: enforce batch atomicity for ActionEvents - # If any ActionEvent in a batch is removed, all ActionEvents in that - # batch should also be removed - removed_event_ids = View._enforce_batch_atomicity(events, removed_event_ids) - - # Third pass: also remove ObservationEvents whose ActionEvents were removed - # due to batch atomicity - tool_call_ids_to_remove: set[ToolCallID] = set() - for action_id in removed_event_ids: - if action_id in action_batch.action_id_to_tool_call_id: - tool_call_ids_to_remove.add( - action_batch.action_id_to_tool_call_id[action_id] - ) - - # Filter out removed events - result = [] - for event in events: - if event.id in removed_event_ids: - continue - if isinstance(event, ObservationBaseEvent): - if event.tool_call_id in tool_call_ids_to_remove: - continue - result.append(event) - - return result - - @staticmethod - def _get_action_tool_call_ids(events: list[LLMConvertibleEvent]) -> set[ToolCallID]: - """Extract tool_call_ids from ActionEvents.""" - tool_call_ids = set() - for event in events: - if isinstance(event, ActionEvent) and event.tool_call_id is not None: - tool_call_ids.add(event.tool_call_id) - return tool_call_ids - - @staticmethod - def _get_observation_tool_call_ids( - events: list[LLMConvertibleEvent], - ) -> set[ToolCallID]: - """Extract tool_call_ids from ObservationEvents.""" - tool_call_ids = set() - for event in events: - if ( - isinstance(event, ObservationBaseEvent) - and event.tool_call_id is not None - ): - tool_call_ids.add(event.tool_call_id) - return tool_call_ids - - @staticmethod - def _should_keep_event( - event: LLMConvertibleEvent, - action_tool_call_ids: set[ToolCallID], - observation_tool_call_ids: set[ToolCallID], - ) -> bool: - """Determine if an event should be kept based on tool call matching.""" - if isinstance(event, ObservationBaseEvent): - return event.tool_call_id in action_tool_call_ids - elif isinstance(event, ActionEvent): - return event.tool_call_id in observation_tool_call_ids - else: - return True - - def find_next_manipulation_index(self, threshold: int, strict: bool = False) -> int: - """Find the smallest manipulation index greater than (or equal to) a threshold. - - This is a helper method for condensation logic that needs to find safe - boundaries for forgetting events. Uses the cached manipulation_indices property. - - Args: - threshold: The threshold value to compare against - strict: If True, finds index > threshold. If False, finds index >= threshold - - Returns: - The smallest manipulation index that satisfies the condition, or the - threshold itself if no such index exists - """ - for idx in self.manipulation_indices: - if strict: - if idx > threshold: - return idx - else: - if idx >= threshold: - return idx - return threshold - - @staticmethod - def from_events(events: Sequence[Event]) -> View: - """Create a view from a list of events, respecting the semantics of any - condensation events. - """ - forgotten_event_ids: set[EventID] = set() - condensations: list[Condensation] = [] - for event in events: - if isinstance(event, Condensation): - condensations.append(event) - forgotten_event_ids.update(event.forgotten_event_ids) - # Make sure we also forget the condensation action itself - forgotten_event_ids.add(event.id) - if isinstance(event, CondensationRequest): - forgotten_event_ids.add(event.id) - - # Enforce batch atomicity: if any event in a multi-action batch is forgotten, - # forget all events in that batch to prevent partial batches with thinking - # blocks separated from their tool calls - forgotten_event_ids = View._enforce_batch_atomicity(events, forgotten_event_ids) - - kept_events = [ - event - for event in events - if event.id not in forgotten_event_ids - and isinstance(event, LLMConvertibleEvent) - ] - - # If we have a summary, insert it at the specified offset. - summary: str | None = None - summary_offset: int | None = None - - # The relevant summary is always in the last condensation event (i.e., the most - # recent one). - for event in reversed(events): - if isinstance(event, Condensation): - if event.summary is not None and event.summary_offset is not None: - summary = event.summary - summary_offset = event.summary_offset - break - - if summary is not None and summary_offset is not None: - logger.debug(f"Inserting summary at offset {summary_offset}") - - _new_summary_event = CondensationSummaryEvent(summary=summary) - kept_events.insert(summary_offset, _new_summary_event) - - # Check for an unhandled condensation request -- these are events closer to the - # end of the list than any condensation action. - unhandled_condensation_request = False - - for event in reversed(events): - if isinstance(event, Condensation): - break - - if isinstance(event, CondensationRequest): - unhandled_condensation_request = True - break - - return View( - events=View.filter_unmatched_tool_calls(kept_events), - unhandled_condensation_request=unhandled_condensation_request, - condensations=condensations, - ) diff --git a/openhands-sdk/openhands/sdk/context/view/README.md b/openhands-sdk/openhands/sdk/context/view/README.md new file mode 100644 index 0000000000..dc70a7da65 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/README.md @@ -0,0 +1,121 @@ +# View + +The `View` class is responsible for representing and manipulating the subset of events that will be provided to the agent's LLM on every step. + +It is closely tied to the context condensation system, and works to ensure the resulting sequence of messages are well-formed and respect the structure expected by common LLM APIs. + +## Architecture Overview + +### Property-Based Design + +The View maintains several **properties** (invariants) that must hold for the event sequence to be valid. Each property has two responsibilities: + +1. **Validation**: Check that the property holds and filter/transform events to enforce it +2. **Manipulation Index Calculation**: Determine "safe boundaries" where events can be inserted or removed without violating the property + +The final set of manipulation indices is computed by taking the **intersection** of the indices from all properties. This ensures that operations at those indices will respect all invariants simultaneously. + +### Why This Matters + +This design provides: +- **Modularity**: Each property is self-contained and independently testable +- **Composability**: New properties can be added without modifying existing ones +- **Clarity**: The interaction between properties is explicit (intersection) +- **Safety**: Manipulation operations are guaranteed to maintain all invariants + +## Properties + +The View maintains four core properties: + +### 1. BatchAtomicityProperty + +**Purpose**: Ensures that ActionEvents sharing the same `llm_response_id` form an atomic unit that cannot be split. + +**Why It Exists**: When an LLM makes a single response containing multiple tool calls, those calls are semantically related. If any one is forgotten (e.g., during condensation), all must be forgotten together to maintain consistency. + +**Validation Logic**: +- Groups ActionEvents by their `llm_response_id` field +- When any ActionEvent in a batch is marked for removal, adds all other ActionEvents from that batch to the removal set +- Uses `ActionBatch.from_events()` to build the mapping + +**Manipulation Index Calculation**: +1. Build mapping: `llm_response_id` → list of ActionEvent indices +2. For each batch, find the min and max indices of all actions +3. Mark the range `[min, max]` as atomic (cannot insert/remove within) +4. Return all indices *outside* these atomic ranges + +**Auxiliary Data**: +- `batches: dict[EventID, list[int]]` - Maps llm_response_id to action indices + +**Example**: +``` +Events: [E0, A1, A2, E3, A4] (A1, A2 share llm_response_id='batch1') +Atomic ranges: [1, 2] +Manipulation indices: {0, 3, 5} (can manipulate before/between/after, not within batch) +``` + +--- + +### 2. ToolLoopAtomicityProperty + +**Purpose**: Ensures that "tool loops" (thinking blocks followed by tool calls) remain atomic units. + +**Why It Exists**: Claude API requires that thinking blocks stay with their associated tool calls. A tool loop is: +- An initial batch containing thinking blocks (ActionEvents with non-empty `thinking_blocks`) +- All subsequent consecutive ActionEvent batches +- Terminated by the first non-ActionEvent/ObservationEvent + +**Validation Logic**: +- Identifies batches that start with thinking blocks +- Extends the atomic unit through all consecutive ActionEvent/ObservationEvent batches +- Does not perform removal (relies on batch atomicity) + +**Manipulation Index Calculation**: +1. Identify batches with thinking blocks (potential tool loop starts) +2. For each such batch, scan forward to find where the tool loop ends (first non-action/observation) +3. Mark entire range as atomic +4. Return all indices *outside* these tool loop ranges + +**Auxiliary Data**: +- `batch_ranges: list[tuple[int, int, bool]]` - (min_idx, max_idx, has_thinking) for each batch +- `tool_loop_ranges: list[tuple[int, int]]` - Start and end indices of tool loops + +**Example**: +``` +Events: [E0, A1(thinking), O1, A2, E3] +Tool loop: [1, 3] (A1 with thinking → O1 → A2, stops at E3) +Manipulation indices: {0, 4, 5} (can only manipulate before loop or after) +``` + +--- + +### 3. ToolCallMatchingProperty + +**Purpose**: Ensures that ActionEvents and ObservationEvents are properly paired via `tool_call_id`. + +**Why It Exists**: LLM APIs expect tool calls to have corresponding observations. Orphaned actions or observations cause API errors. + +**Validation Logic**: +1. Extract all `tool_call_id` values from ActionEvents +2. Extract all `tool_call_id` values from ObservationEvents (includes ObservationEvent, UserRejectObservation, AgentErrorEvent) +3. Keep ActionEvents only if their `tool_call_id` exists in observations +4. Keep ObservationEvents only if their `tool_call_id` exists in actions +5. Keep all other event types unconditionally + +**Manipulation Index Calculation**: +- All indices are valid for this property (no restrictions on boundaries) +- Validation happens through filtering, not boundary restriction +- Returns `set(range(len(events) + 1))` + +**Auxiliary Data**: +- `action_tool_call_ids: set[ToolCallID]` - Tool call IDs from actions +- `observation_tool_call_ids: set[ToolCallID]` - Tool call IDs from observations + +**Example**: +``` +Events: [A1(tc_1), O1(tc_1), A2(tc_2)] +A2 has no matching observation → filtered out +Result: [A1(tc_1), O1(tc_1)] +``` + +--- diff --git a/openhands-sdk/openhands/sdk/context/view/__init__.py b/openhands-sdk/openhands/sdk/context/view/__init__.py new file mode 100644 index 0000000000..ad4f7d8e42 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/__init__.py @@ -0,0 +1,5 @@ +from openhands.sdk.context.view.manipulation_indices import ManipulationIndices +from openhands.sdk.context.view.view import View + + +__all__ = ["ManipulationIndices", "View"] diff --git a/openhands-sdk/openhands/sdk/context/view/manipulation_indices.py b/openhands-sdk/openhands/sdk/context/view/manipulation_indices.py new file mode 100644 index 0000000000..791106eb82 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/manipulation_indices.py @@ -0,0 +1,37 @@ +class ManipulationIndices(set[int]): + """A set of indices where events can be safely manipulated. + + This class extends set[int] to provide utility methods for finding + the next valid manipulation index given a threshold. + """ + + def find_next(self, threshold: int, strict: bool = False) -> int: + """Find the smallest manipulation index greater than (or equal to) a threshold. + + This is a helper method for condensation logic that needs to find safe + boundaries for forgetting events. + + Args: + threshold: The threshold value to compare against + strict: If True, finds index > threshold. If False, finds index >= threshold + + Returns: + The smallest manipulation index that satisfies the condition + + Raises: + ValueError: If no valid manipulation index exists that satisfies + the condition + """ + if strict: + valid_indices = [idx for idx in self if idx > threshold] + else: + valid_indices = [idx for idx in self if idx >= threshold] + + if not valid_indices: + operator = ">" if strict else ">=" + raise ValueError( + f"No manipulation index found {operator} {threshold}. " + f"Available indices: {sorted(self)}" + ) + + return min(valid_indices) diff --git a/openhands-sdk/openhands/sdk/context/view/properties/__init__.py b/openhands-sdk/openhands/sdk/context/view/properties/__init__.py new file mode 100644 index 0000000000..4d17075710 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/properties/__init__.py @@ -0,0 +1,18 @@ +from openhands.sdk.context.view.properties.base import ViewPropertyBase +from openhands.sdk.context.view.properties.batch_atomicity import ( + BatchAtomicityProperty, +) +from openhands.sdk.context.view.properties.tool_call_matching import ( + ToolCallMatchingProperty, +) +from openhands.sdk.context.view.properties.tool_loop_atomicity import ( + ToolLoopAtomicityProperty, +) + + +__all__ = [ + "ViewPropertyBase", + "BatchAtomicityProperty", + "ToolCallMatchingProperty", + "ToolLoopAtomicityProperty", +] diff --git a/openhands-sdk/openhands/sdk/context/view/properties/base.py b/openhands-sdk/openhands/sdk/context/view/properties/base.py new file mode 100644 index 0000000000..a7a0da222a --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/properties/base.py @@ -0,0 +1,135 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Sequence + +from openhands.sdk.context.view.manipulation_indices import ManipulationIndices +from openhands.sdk.event.base import Event, LLMConvertibleEvent +from openhands.sdk.event.llm_convertible.action import ActionEvent +from openhands.sdk.event.types import EventID + + +class ViewPropertyBase(ABC): + """Abstract base class for properties of a view. + + Properties define rules that help maintain the integrity and coherence of the events + in the view. The properties are maintained via two strategies: + + 1. Enforcing the property by removing events that violate it. + 2. Defining manipulation indices that restrict where the view can be manipulated. + + In an ideal scenario, sticking to the manipulation indices should suffice to ensure + the property holds. Enforcement is only intended as a fallback mechanism to handle + edge cases, bad data, or unforeseen situations. + """ + + @abstractmethod + def enforce( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], + ) -> set[EventID]: + """Enforce the property on a list of events. + + Args: + current_view_events: A list of events currently in the view. + all_events: A list of all Event objects in the conversation. Useful for + properties that need to reference events outside the current view. + + Returns: + A set of EventID objects to be removed from the current view to enforce the + property. + """ + pass + + @abstractmethod + def manipulation_indices( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], + ) -> ManipulationIndices: + """Get manipulation indices for the property on a list of events. + + Args: + current_view_events: A list of events currently in the view. + all_events: A list of all Event objects in the conversation. Useful for + properties that need to reference events outside the current view. + + Returns: + A ManipulationIndices object defining where the view can be manipulated + while maintaining the property. + """ + pass + + @staticmethod + def _build_batches(events: Sequence[Event]) -> dict[EventID, list[EventID]]: + """Build mapping of llm_response_id to ActionEvent IDs. + + Args: + events: Sequence of events to analyze + + Returns: + Dictionary mapping llm_response_id to list of ActionEvent IDs + """ + batches: dict[EventID, list[EventID]] = defaultdict(list) + for event in events: + if isinstance(event, ActionEvent): + batches[event.llm_response_id].append(event.id) + return dict(batches) + + @staticmethod + def _build_event_id_to_index(events: Sequence[Event]) -> dict[EventID, int]: + """Build mapping of event ID to index. + + Args: + events: Sequence of events to analyze + + Returns: + Dictionary mapping event ID to index in the list + """ + return {event.id: idx for idx, event in enumerate(events)} + + @staticmethod + def _get_batch_extent( + action_ids: list[EventID], + event_id_to_index: dict[EventID, int], + ) -> tuple[int, int]: + """Get the min and max indices for a batch of action IDs. + + Args: + action_ids: List of ActionEvent IDs in the batch + event_id_to_index: Mapping of event IDs to indices + + Returns: + Tuple of (min_index, max_index) for the batch + """ + indices = [event_id_to_index[aid] for aid in action_ids] + return min(indices), max(indices) + + @staticmethod + def _build_manipulation_indices_from_atomic_ranges( + atomic_ranges: list[tuple[int, int]], + num_events: int, + ) -> ManipulationIndices: + """Build ManipulationIndices that exclude indices within atomic ranges. + + Atomic ranges represent contiguous sequences of events that must remain + together. This method creates a set of valid manipulation indices that + excludes all indices that fall within these atomic ranges. + + Args: + atomic_ranges: List of (start_idx, end_idx) tuples defining atomic ranges + num_events: Total number of events in the view + + Returns: + ManipulationIndices with valid indices (excluding atomic ranges) + """ + # Start with all possible indices (including after the last event) + valid_indices = set(range(num_events + 1)) + + # Remove indices that fall within atomic ranges + for start_idx, end_idx in atomic_ranges: + # Cannot insert/remove within the atomic range (exclusive of start boundary) + for idx in range(start_idx + 1, end_idx + 1): + valid_indices.discard(idx) + + return ManipulationIndices(valid_indices) diff --git a/openhands-sdk/openhands/sdk/context/view/properties/batch_atomicity.py b/openhands-sdk/openhands/sdk/context/view/properties/batch_atomicity.py new file mode 100644 index 0000000000..f5cdd39b67 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/properties/batch_atomicity.py @@ -0,0 +1,106 @@ +"""Property for ensuring ActionEvent batches remain atomic.""" + +from collections.abc import Sequence + +from openhands.sdk.context.view.manipulation_indices import ManipulationIndices +from openhands.sdk.context.view.properties.base import ViewPropertyBase +from openhands.sdk.event.base import Event, LLMConvertibleEvent +from openhands.sdk.event.types import EventID + + +class BatchAtomicityProperty(ViewPropertyBase): + """Ensures ActionEvents sharing the same llm_response_id form an atomic unit. + + When an LLM makes a single response containing multiple tool calls, those calls + are semantically related. If any one is forgotten (e.g., during condensation), + all must be forgotten together to maintain consistency. + """ + + def enforce( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], + ) -> set[EventID]: + """Enforce batch atomicity by marking all events in a partially-removed batch. + + If any ActionEvent in a batch is missing, this method will mark all other + ActionEvents from that batch for removal. + + Args: + current_view_events: Events currently in the view + all_events: All events in the conversation + + Returns: + Set of EventIDs to remove from the current view + """ + # Build mappings from all events to understand complete batches + all_batches = self._build_batches(all_events) + view_batches = self._build_batches(current_view_events) + + events_to_remove: set[EventID] = set() + + # Check each batch in the original events + for llm_response_id, action_ids in all_batches.items(): + # Get which actions from this batch are in the view + actions_in_view = view_batches.get(llm_response_id, []) + + # If batch is partially present (some but not all actions) + if actions_in_view and len(actions_in_view) < len(action_ids): + # Remove all actions from this batch from the view + events_to_remove.update(actions_in_view) + + return events_to_remove + + def manipulation_indices( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], # noqa: ARG002 + ) -> ManipulationIndices: + """Calculate manipulation indices that respect batch atomicity. + + Returns all indices outside of batch ranges. Within a batch (from min to max + index), no manipulation is allowed. The range includes all actions in the batch + and their corresponding observations. + + Args: + current_view_events: Events currently in the view + all_events: All events in the conversation + + Returns: + ManipulationIndices with all valid manipulation points + """ + from openhands.sdk.event.llm_convertible import ( + ActionEvent, + ObservationBaseEvent, + ) + + batches = self._build_batches(current_view_events) + event_id_to_index = self._build_event_id_to_index(current_view_events) + + # Build tool_call_id to observation mapping + tool_call_to_obs_idx: dict[str, int] = {} + for idx, event in enumerate(current_view_events): + if isinstance(event, ObservationBaseEvent) and event.tool_call_id: + tool_call_to_obs_idx[event.tool_call_id] = idx + + # Find atomic ranges for each batch + atomic_ranges: list[tuple[int, int]] = [] + + for llm_response_id, action_ids in batches.items(): + # Get indices for all actions in this batch + min_idx, max_idx = self._get_batch_extent(action_ids, event_id_to_index) + + # Extend max_idx to include all corresponding observations + for action_id in action_ids: + action_event = current_view_events[event_id_to_index[action_id]] + if isinstance(action_event, ActionEvent) and action_event.tool_call_id: + obs_idx = tool_call_to_obs_idx.get(action_event.tool_call_id) + if obs_idx is not None: + max_idx = max(max_idx, obs_idx) + + atomic_ranges.append((min_idx, max_idx)) + + # Build manipulation indices that exclude atomic ranges + return self._build_manipulation_indices_from_atomic_ranges( + atomic_ranges, len(current_view_events) + ) diff --git a/openhands-sdk/openhands/sdk/context/view/properties/tool_call_matching.py b/openhands-sdk/openhands/sdk/context/view/properties/tool_call_matching.py new file mode 100644 index 0000000000..f824b745a5 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/properties/tool_call_matching.py @@ -0,0 +1,109 @@ +"""Property for ensuring ActionEvents and ObservationEvents are properly paired.""" + +from collections.abc import Sequence + +from openhands.sdk.context.view.manipulation_indices import ManipulationIndices +from openhands.sdk.context.view.properties.base import ViewPropertyBase +from openhands.sdk.event.base import Event, LLMConvertibleEvent +from openhands.sdk.event.llm_convertible.action import ActionEvent +from openhands.sdk.event.llm_convertible.observation import ObservationBaseEvent +from openhands.sdk.event.types import EventID, ToolCallID + + +class ToolCallMatchingProperty(ViewPropertyBase): + """Ensures ActionEvents and ObservationEvents are properly paired via tool_call_id. + + LLM APIs expect tool calls to have corresponding observations. Orphaned actions + or observations cause API errors. + """ + + @staticmethod + def _extract_action_tool_call_ids(events: Sequence[Event]) -> set[ToolCallID]: + """Extract all tool_call_ids from ActionEvents. + + Args: + events: Sequence of events to analyze + + Returns: + Set of tool_call_ids from ActionEvents + """ + tool_call_ids: set[ToolCallID] = set() + for event in events: + if isinstance(event, ActionEvent) and event.tool_call_id is not None: + tool_call_ids.add(event.tool_call_id) + return tool_call_ids + + @staticmethod + def _extract_observation_tool_call_ids( + events: Sequence[Event], + ) -> set[ToolCallID]: + """Extract all tool_call_ids from ObservationBaseEvents. + + Args: + events: Sequence of events to analyze + + Returns: + Set of tool_call_ids from ObservationBaseEvents + """ + tool_call_ids: set[ToolCallID] = set() + for event in events: + if ( + isinstance(event, ObservationBaseEvent) + and event.tool_call_id is not None + ): + tool_call_ids.add(event.tool_call_id) + return tool_call_ids + + def enforce( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], # noqa: ARG002 + ) -> set[EventID]: + """Enforce tool call matching by removing orphaned actions and observations. + + Args: + current_view_events: Events currently in the view + all_events: All events in the conversation + + Returns: + Set of EventIDs to remove from the current view + """ + action_tool_call_ids = self._extract_action_tool_call_ids(current_view_events) + observation_tool_call_ids = self._extract_observation_tool_call_ids( + current_view_events + ) + + events_to_remove: set[EventID] = set() + + # Remove ActionEvents without matching observations + for event in current_view_events: + if isinstance(event, ActionEvent): + if event.tool_call_id not in observation_tool_call_ids: + events_to_remove.add(event.id) + + # Remove ObservationEvents without matching actions + elif isinstance(event, ObservationBaseEvent): + if event.tool_call_id not in action_tool_call_ids: + events_to_remove.add(event.id) + + return events_to_remove + + def manipulation_indices( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], # noqa: ARG002 + ) -> ManipulationIndices: + """Calculate manipulation indices for tool call matching. + + All indices are valid for this property. Validation happens through + filtering in the enforce method, not through boundary restriction. + + Args: + current_view_events: Events currently in the view + all_events: All events in the conversation + + Returns: + ManipulationIndices with all indices valid + """ + # All indices are valid - filtering is done via enforce() + return ManipulationIndices(set(range(len(current_view_events) + 1))) diff --git a/openhands-sdk/openhands/sdk/context/view/properties/tool_loop_atomicity.py b/openhands-sdk/openhands/sdk/context/view/properties/tool_loop_atomicity.py new file mode 100644 index 0000000000..0ed3c497ff --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/properties/tool_loop_atomicity.py @@ -0,0 +1,246 @@ +"""Property for ensuring tool loops remain atomic.""" + +from collections.abc import Sequence + +from openhands.sdk.context.view.manipulation_indices import ManipulationIndices +from openhands.sdk.context.view.properties.base import ViewPropertyBase +from openhands.sdk.event.base import Event, LLMConvertibleEvent +from openhands.sdk.event.llm_convertible.action import ActionEvent +from openhands.sdk.event.llm_convertible.observation import ObservationBaseEvent +from openhands.sdk.event.types import EventID + + +class ToolLoopAtomicityProperty(ViewPropertyBase): + """Ensures that tool loops (thinking blocks + tool calls) remain atomic units. + + Claude API requires that thinking blocks stay with their associated tool calls. + A tool loop is: + - An initial batch with thinking blocks (ActionEvents w/ non-empty thinking_blocks) + - All subsequent consecutive ActionEvent/ObservationEvent batches + - Terminated by the first non-ActionEvent/ObservationEvent + """ + + def _build_batch_ranges( + self, + batches: dict[EventID, list[EventID]], + events: Sequence[Event], + event_id_to_index: dict[EventID, int], + ) -> list[tuple[int, int, bool, list[EventID]]]: + """Build batch range metadata for tool loop detection. + + Args: + batches: Mapping of llm_response_id to action event IDs + events: Event sequence to analyze + event_id_to_index: Mapping of event IDs to their indices + + Returns: + List of tuples (min_idx, max_idx, has_thinking, action_ids) by min_idx + """ + batch_ranges: list[tuple[int, int, bool, list[EventID]]] = [] + + for llm_response_id, action_ids in batches.items(): + # Get indices for all actions in this batch + min_idx, max_idx = self._get_batch_extent(action_ids, event_id_to_index) + + # Check if any action in this batch has thinking blocks + has_thinking = False + for action_id in action_ids: + idx = event_id_to_index[action_id] + event = events[idx] + if isinstance(event, ActionEvent) and event.thinking_blocks: + has_thinking = True + break + + batch_ranges.append((min_idx, max_idx, has_thinking, action_ids)) + + # Sort batch ranges by min_idx + batch_ranges.sort(key=lambda x: x[0]) + return batch_ranges + + def _scan_tool_loop_extent( + self, + start_idx: int, + batch_ranges: list[tuple[int, int, bool, list[EventID]]], + events: Sequence[Event], + ) -> tuple[int, int, int]: + """Scan forward from a starting batch to find the full extent of a tool loop. + + Args: + start_idx: Index in batch_ranges where the tool loop starts + (must have has_thinking=True) + batch_ranges: Sorted list of batch range tuples + events: Event sequence being analyzed + + Returns: + Tuple of (loop_start_event_idx, loop_end_event_idx, next_batch_idx) + - loop_start_event_idx: Index of first event in the tool loop + - loop_end_event_idx: Index of last event in the tool loop + - next_batch_idx: Index in batch_ranges after this loop ends + """ + min_idx, max_idx, has_thinking, _ = batch_ranges[start_idx] + + if not has_thinking: + raise ValueError( + "Tool loop must start with a batch containing thinking blocks" + ) + + loop_start = min_idx + loop_end = max_idx + + # Scan forward through consecutive action/observation batches + j = start_idx + 1 + while j < len(batch_ranges): + next_min, next_max, _, _ = batch_ranges[j] + + # Check if there are only ActionEvents/ObservationEvents between + # current loop_end and next_min + all_action_or_obs = True + for idx in range(loop_end + 1, next_min): + event = events[idx] + if not isinstance(event, (ActionEvent, ObservationBaseEvent)): + all_action_or_obs = False + break + + if all_action_or_obs: + # Extend the tool loop + loop_end = next_max + j += 1 + else: + # Tool loop ends here + break + + # Scan forward to include any trailing observations + scan_idx = loop_end + 1 + while scan_idx < len(events): + event = events[scan_idx] + if isinstance(event, ObservationBaseEvent): + loop_end = scan_idx + scan_idx += 1 + elif isinstance(event, ActionEvent): + # Another action - should have been caught by batch processing above + break + else: + # Non-action/observation terminates the loop + break + + return loop_start, loop_end, j + + def _identify_tool_loops(self, events: Sequence[Event]) -> list[list[EventID]]: + """Identify all tool loops in the event sequence. + + Returns: + List of tool loops, where each tool loop is a list of EventIDs + """ + batches = self._build_batches(events) + event_id_to_index = self._build_event_id_to_index(events) + + # Build batch ranges with metadata using helper + batch_ranges = self._build_batch_ranges(batches, events, event_id_to_index) + + # Identify tool loops + tool_loops: list[list[EventID]] = [] + + i = 0 + while i < len(batch_ranges): + _, _, has_thinking, action_ids = batch_ranges[i] + + if has_thinking: + # Use helper to find the full extent of this tool loop + loop_start, loop_end, next_i = self._scan_tool_loop_extent( + i, batch_ranges, events + ) + + # Collect all event IDs within the loop range + loop_event_ids: list[EventID] = [] + for idx in range(loop_start, loop_end + 1): + loop_event_ids.append(events[idx].id) + + tool_loops.append(loop_event_ids) + i = next_i + else: + i += 1 + + return tool_loops + + def enforce( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], + ) -> set[EventID]: + """Enforce tool loop atomicity by removing partially-present tool loops. + + If a tool loop is partially present in the view, all events from that + tool loop are removed. + + Args: + current_view_events: Events currently in the view + all_events: All events in the conversation + + Returns: + Set of EventIDs to remove from the current view + """ + # Identify all tool loops in the complete conversation + tool_loops = self._identify_tool_loops(all_events) + + # Build set of event IDs currently in view + view_event_ids = {event.id for event in current_view_events} + + events_to_remove: set[EventID] = set() + + # Check each tool loop + for loop_event_ids in tool_loops: + # Count how many events from this loop are in the view + events_in_view = [eid for eid in loop_event_ids if eid in view_event_ids] + + # If loop is partially present (some but not all events) + if events_in_view and len(events_in_view) < len(loop_event_ids): + # Remove all events from this loop that are in the view + events_to_remove.update(events_in_view) + + return events_to_remove + + def manipulation_indices( + self, + current_view_events: Sequence[LLMConvertibleEvent], + all_events: Sequence[Event], # noqa: ARG002 + ) -> ManipulationIndices: + """Calculate manipulation indices that respect tool loop atomicity. + + Returns all indices outside of tool loop ranges. + + Args: + current_view_events: Events currently in the view + all_events: All events in the conversation + + Returns: + ManipulationIndices with all valid manipulation points + """ + batches = self._build_batches(current_view_events) + event_id_to_index = self._build_event_id_to_index(current_view_events) + + # Build batch ranges with metadata using helper + batch_ranges = self._build_batch_ranges( + batches, current_view_events, event_id_to_index + ) + + # Identify tool loop ranges + tool_loop_ranges: list[tuple[int, int]] = [] + + i = 0 + while i < len(batch_ranges): + _, _, has_thinking, _ = batch_ranges[i] + + if has_thinking: + # Use helper to find the full extent of this tool loop + loop_start, loop_end, next_i = self._scan_tool_loop_extent( + i, batch_ranges, current_view_events + ) + tool_loop_ranges.append((loop_start, loop_end)) + i = next_i + else: + i += 1 + + # Build manipulation indices that exclude tool loop ranges + return self._build_manipulation_indices_from_atomic_ranges( + tool_loop_ranges, len(current_view_events) + ) diff --git a/openhands-sdk/openhands/sdk/context/view/view.py b/openhands-sdk/openhands/sdk/context/view/view.py new file mode 100644 index 0000000000..1f5249c64e --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/view/view.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from collections.abc import Sequence +from logging import getLogger +from typing import overload + +from pydantic import BaseModel, Field + +from openhands.sdk.context.view.manipulation_indices import ManipulationIndices +from openhands.sdk.context.view.properties.batch_atomicity import ( + BatchAtomicityProperty, +) +from openhands.sdk.context.view.properties.tool_call_matching import ( + ToolCallMatchingProperty, +) +from openhands.sdk.context.view.properties.tool_loop_atomicity import ( + ToolLoopAtomicityProperty, +) +from openhands.sdk.event import ( + Condensation, + CondensationRequest, + CondensationSummaryEvent, + LLMConvertibleEvent, +) +from openhands.sdk.event.base import Event, EventID + + +logger = getLogger(__name__) + + +class View(BaseModel): + """Linearly ordered view of events. + + Produced by a condenser to indicate the included events are ready to process as LLM + input. Also contains fields with information from the condensation process to aid + in deciding whether further condensation is needed. + """ + + model_config = {"arbitrary_types_allowed": True} + + events: list[LLMConvertibleEvent] + + unhandled_condensation_request: bool = False + """Whether there is an unhandled condensation request in the view.""" + + condensations: list[Condensation] = [] + """A list of condensations that were processed to produce the view.""" + + def __len__(self) -> int: + return len(self.events) + + @property + def most_recent_condensation(self) -> Condensation | None: + """Return the most recent condensation, or None if no condensations exist.""" + return self.condensations[-1] if self.condensations else None + + @property + def summary_event_index(self) -> int | None: + """Return the index of the summary event, or None if no summary exists.""" + recent_condensation = self.most_recent_condensation + if ( + recent_condensation is not None + and recent_condensation.summary is not None + and recent_condensation.summary_offset is not None + ): + return recent_condensation.summary_offset + return None + + @property + def summary_event(self) -> CondensationSummaryEvent | None: + """Return the summary event, or None if no summary exists.""" + if self.summary_event_index is not None: + event = self.events[self.summary_event_index] + if isinstance(event, CondensationSummaryEvent): + return event + return None + + manipulation_indices: ManipulationIndices = Field( + description=("Manipulation indices for this view's events. ") + ) + + # To preserve list-like indexing, we ideally support slicing and position-based + # indexing. The only challenge with that is switching the return type based on the + # input type -- we can mark the different signatures for MyPy with `@overload` + # decorators. + + @overload + def __getitem__(self, key: slice) -> list[LLMConvertibleEvent]: ... + + @overload + def __getitem__(self, key: int) -> LLMConvertibleEvent: ... + + def __getitem__( + self, key: int | slice + ) -> LLMConvertibleEvent | list[LLMConvertibleEvent]: + if isinstance(key, slice): + start, stop, step = key.indices(len(self)) + return [self[i] for i in range(start, stop, step)] + elif isinstance(key, int): + return self.events[key] + else: + raise ValueError(f"Invalid key type: {type(key)}") + + def find_next_manipulation_index(self, threshold: int, strict: bool = False) -> int: + """Find the smallest manipulation index greater than (or equal to) a threshold. + + This is a helper method for condensation logic that needs to find safe + boundaries for forgetting events. Uses the cached manipulation_indices property. + + Args: + threshold: The threshold value to compare against + strict: If True, finds index > threshold. If False, finds index >= threshold + + Returns: + The smallest manipulation index that satisfies the condition, or the + threshold itself if no such index exists + """ + return self.manipulation_indices.find_next(threshold, strict) + + @staticmethod + def _enforce_properties( + view_events: list[LLMConvertibleEvent], + all_events: Sequence[Event], + properties: list, + ) -> list[LLMConvertibleEvent]: + """Enforce properties iteratively until no violations remain. + + Properties are checked in order, and we restart from the first property + whenever any property removes events (to handle cascading effects). + + Args: + view_events: Initial list of events in the view + all_events: Complete list of all events in the conversation + properties: List of property instances to enforce + + Returns: + Filtered list of events with all property violations resolved + """ + current_events = view_events + + while True: + events_removed_this_iteration: set[EventID] = set() + + for prop in properties: + events_to_remove = prop.enforce(current_events, all_events) + if events_to_remove: + logger.debug( + f"{prop.__class__.__name__} removing " + f"{len(events_to_remove)} events" + ) + events_removed_this_iteration.update(events_to_remove) + # Exit inner loop and restart from first property + break + + if not events_removed_this_iteration: + # No events removed by any property - enforcement complete + break + + # Remove events and continue iterating + current_events = [ + e for e in current_events if e.id not in events_removed_this_iteration + ] + + return current_events + + @staticmethod + def _calculate_manipulation_indices( + view_events: list[LLMConvertibleEvent], + all_events: Sequence[Event], + properties: list, + ) -> ManipulationIndices: + """Calculate manipulation indices by intersecting all property indices. + + Args: + view_events: Events in the view + all_events: Complete list of all events in the conversation + properties: List of property instances + + Returns: + ManipulationIndices representing safe boundaries for event manipulation + """ + if not view_events: + return ManipulationIndices({0}) + + # Get manipulation indices from each property and intersect them + all_indices = [ + prop.manipulation_indices(view_events, all_events) for prop in properties + ] + return ManipulationIndices( + set.intersection(*all_indices) if all_indices else set() + ) + + @staticmethod + def _unhandled_condensation_request( + events: Sequence[Event], + ) -> bool: + """Check for an unhandled condensation request in the event list.""" + for event in reversed(events): + if isinstance(event, Condensation): + return False + if isinstance(event, CondensationRequest): + return True + return False + + @staticmethod + def _apply_condensations( + events: Sequence[Event], + ) -> tuple[list[LLMConvertibleEvent], list[Condensation]]: + """Apply condensations to the event list, removing forgotten events.""" + forgotten_event_ids: set[EventID] = set() + condensations: list[Condensation] = [] + for event in events: + if isinstance(event, Condensation): + condensations.append(event) + forgotten_event_ids.update(event.forgotten_event_ids) + # Make sure we also forget the condensation action itself + forgotten_event_ids.add(event.id) + if isinstance(event, CondensationRequest): + forgotten_event_ids.add(event.id) + + kept_events = [ + event + for event in events + if event.id not in forgotten_event_ids + and isinstance(event, LLMConvertibleEvent) + ] + + # If we have a summary, insert it at the specified offset. + summary: str | None = None + summary_offset: int | None = None + + # The relevant summary is always in the last condensation event (i.e., the most + # recent one). + for event in reversed(events): + if isinstance(event, Condensation): + if event.summary is not None and event.summary_offset is not None: + summary = event.summary + summary_offset = event.summary_offset + break + + if summary is not None and summary_offset is not None: + logger.debug(f"Inserting summary at offset {summary_offset}") + + _new_summary_event = CondensationSummaryEvent(summary=summary) + kept_events.insert(summary_offset, _new_summary_event) + return kept_events, condensations + + @staticmethod + def from_events(events: Sequence[Event]) -> View: + """Create a view from a list of events, respecting the semantics of any + condensation events. + """ + kept_events, condensations = View._apply_condensations(events) + + # Check for an unhandled condensation request -- these are events closer to the + # end of the list than any condensation action. + unhandled_condensation_request = View._unhandled_condensation_request(events) + + # Define view properties for enforcement and manipulation indices + properties = [ + ToolCallMatchingProperty(), + BatchAtomicityProperty(), + ToolLoopAtomicityProperty(), + ] + + # Apply property enforcement to remove violations + view_events = View._enforce_properties(kept_events, events, properties) + + # Calculate manipulation_indices by taking intersection of all properties + manipulation_indices = View._calculate_manipulation_indices( + view_events, events, properties + ) + + return View( + events=view_events, + unhandled_condensation_request=unhandled_condensation_request, + condensations=condensations, + manipulation_indices=manipulation_indices, + ) diff --git a/tests/integration/tests/c01_thinking_block_condenser.py b/tests/integration/tests/c01_thinking_block_condenser.py index 396a7955f9..beb550adc3 100644 --- a/tests/integration/tests/c01_thinking_block_condenser.py +++ b/tests/integration/tests/c01_thinking_block_condenser.py @@ -49,7 +49,7 @@ def condense(self, view: View, agent_llm: LLM | None = None) -> View | Condensat 3. Later thinking blocks are preserved """ # Get manipulation indices which define boundaries of atomic units - indices = view.manipulation_indices + indices = sorted(view.manipulation_indices) # Find atomic units (ranges between consecutive indices) with thinking blocks units_with_thinking = [] diff --git a/tests/sdk/context/test_tool_loop_boundaries.py b/tests/sdk/context/test_tool_loop_boundaries.py index 84dc4f8a50..0327937752 100644 --- a/tests/sdk/context/test_tool_loop_boundaries.py +++ b/tests/sdk/context/test_tool_loop_boundaries.py @@ -91,11 +91,11 @@ def test_single_batch_with_thinking(): view = View.from_events(events) indices = view.manipulation_indices - # Should have boundaries: [0, 1, 3] + # Should have boundaries: {0, 1, 3} # - 0: before user message # - 1: before tool loop (action + observation) # - 3: after tool loop - assert indices == [0, 1, 3] + assert indices == {0, 1, 3} def test_tool_loop_multiple_batches(): @@ -126,12 +126,12 @@ def test_tool_loop_multiple_batches(): view = View.from_events(events) indices = view.manipulation_indices - # Should have boundaries: [0, 1, 7, 8] + # Should have boundaries: {0, 1, 7, 8} # - 0: before first user message # - 1: before tool loop (all 3 batches are one atomic unit) # - 7: after tool loop, before second user message # - 8: after second user message - assert indices == [0, 1, 7, 8] + assert indices == {0, 1, 7, 8} def test_tool_loop_ends_at_non_batch_event(): @@ -161,13 +161,13 @@ def test_tool_loop_ends_at_non_batch_event(): view = View.from_events(events) indices = view.manipulation_indices - # Should have boundaries: [0, 1, 5, 6, 8] + # Should have boundaries: {0, 1, 5, 6, 8} # - 0: before first user message # - 1: before first tool loop (batches 1-2) # - 5: after first tool loop, before second user message # - 6: after second user message, before second tool loop # - 8: after second tool loop - assert indices == [0, 1, 5, 6, 8] + assert indices == {0, 1, 5, 6, 8} def test_batch_without_thinking_not_a_tool_loop(): @@ -185,13 +185,13 @@ def test_batch_without_thinking_not_a_tool_loop(): view = View.from_events(events) indices = view.manipulation_indices - # Should have boundaries: [0, 1, 3, 5] + # Should have boundaries: {0, 1, 3, 5} # Each batch is separate since no thinking blocks # - 0: before user message # - 1: before first batch # - 3: after first batch, before second batch # - 5: after second batch - assert indices == [0, 1, 3, 5] + assert indices == {0, 1, 3, 5} def test_multiple_separate_tool_loops(): @@ -221,14 +221,14 @@ def test_multiple_separate_tool_loops(): view = View.from_events(events) indices = view.manipulation_indices - # Should have boundaries: [0, 1, 5, 6, 8, 9] + # Should have boundaries: {0, 1, 5, 6, 8, 9} # - 0: before user 1 # - 1: before first tool loop # - 5: after first tool loop, before user 2 # - 6: after user 2, before second tool loop # - 8: after second tool loop, before user 3 # - 9: after user 3 - assert indices == [0, 1, 5, 6, 8, 9] + assert indices == {0, 1, 5, 6, 8, 9} def test_parallel_tool_calls_in_tool_loop(): @@ -257,19 +257,19 @@ def test_parallel_tool_calls_in_tool_loop(): view = View.from_events(events) indices = view.manipulation_indices - # Should have boundaries: [0, 1, 7, 8] + # Should have boundaries: {0, 1, 7, 8} # - 0: before user message # - 1: before tool loop (includes both batches) # - 7: after tool loop, before next user message # - 8: after next user message - assert indices == [0, 1, 7, 8] + assert indices == {0, 1, 7, 8} def test_empty_events(): """Test manipulation indices with empty events list.""" view = View.from_events([]) indices = view.manipulation_indices - assert indices == [0] + assert indices == {0} def test_only_user_messages(): @@ -286,4 +286,4 @@ def test_only_user_messages(): # - 0: before first message # - 1: after first message, before second message # - 2: after second message - assert indices == [0, 1, 2] + assert indices == {0, 1, 2} diff --git a/tests/sdk/context/test_view_action_filtering.py b/tests/sdk/context/test_view_action_filtering.py deleted file mode 100644 index d7da64c934..0000000000 --- a/tests/sdk/context/test_view_action_filtering.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Test for confirmation mode issue with condenser view filtering. - -This test reproduces the issue where ActionEvents are incorrectly filtered out -when paired with UserRejectObservation or AgentErrorEvent instead of ObservationEvent. -""" - -from unittest.mock import create_autospec - -from openhands.sdk.context.view import View -from openhands.sdk.event.llm_convertible import ( - ActionEvent, - AgentErrorEvent, - MessageEvent, - ObservationEvent, - UserRejectObservation, -) -from openhands.sdk.llm import Message, TextContent - - -def message_event(content: str) -> MessageEvent: - """Helper to create a MessageEvent.""" - return MessageEvent( - llm_message=Message(role="user", content=[TextContent(text=content)]), - source="user", - ) - - -def test_filter_unmatched_tool_calls_with_user_reject_observation() -> None: - """Test that ActionEvent paired with UserRejectObservation is not filtered out. - - This reproduces the confirmation mode issue where user rejection causes - ActionEvents to be incorrectly filtered out by the condenser. - """ - # Create a mock ActionEvent with tool_call_id - action_event = create_autospec(ActionEvent, instance=True) - action_event.tool_call_id = "call_1" - action_event.id = "action_1" - action_event.llm_response_id = "response_1" - - # Create a UserRejectObservation that responds to the action - user_reject_obs = UserRejectObservation( - action_id="action_1", - tool_name="TerminalTool", - tool_call_id="call_1", - rejection_reason="User rejected the action", - ) - - # Create some other events - message1 = message_event("First message") - message2 = message_event("Second message") - - events = [ - message1, - action_event, - user_reject_obs, - message2, - ] - - # Filter the events - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # Both the ActionEvent and UserRejectObservation should be kept - # because they form a matched pair (after the fix) - assert len(result) == 4 - assert action_event in result - assert user_reject_obs in result - assert message1 in result - assert message2 in result - - -def test_filter_unmatched_tool_calls_with_agent_error_event() -> None: - """Test that ActionEvent paired with AgentErrorEvent is not filtered out. - - This tests the case where an agent error occurs during tool execution. - """ - # Create a mock ActionEvent with tool_call_id - action_event = create_autospec(ActionEvent, instance=True) - action_event.tool_call_id = "call_1" - action_event.id = "action_1" - action_event.llm_response_id = "response_1" - - # Create an AgentErrorEvent that responds to the action - # After the fix, AgentErrorEvent should have tool_name and tool_call_id fields - agent_error = AgentErrorEvent( - error="Tool execution failed", - tool_name="TerminalTool", - tool_call_id="call_1", - ) - - # Create some other events - message1 = message_event("First message") - message2 = message_event("Second message") - - events = [ - message1, - action_event, - agent_error, - message2, - ] - - # Filter the events - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # Both the ActionEvent and AgentErrorEvent should be kept - # because they form a matched pair (after the fix) - assert len(result) == 4 - assert action_event in result - assert agent_error in result - assert message1 in result - assert message2 in result - - -def test_filter_unmatched_tool_calls_mixed_observation_types() -> None: - """Test filtering with mixed observation types. - - This tests a scenario with normal ObservationEvent, UserRejectObservation, - and AgentErrorEvent to ensure proper filtering behavior. - """ - # Create ActionEvents - action_event_1 = create_autospec(ActionEvent, instance=True) - action_event_1.tool_call_id = "call_1" - action_event_1.id = "action_1" - action_event_1.llm_response_id = "response_1" - - action_event_2 = create_autospec(ActionEvent, instance=True) - action_event_2.tool_call_id = "call_2" - action_event_2.id = "action_2" - action_event_2.llm_response_id = "response_2" - - action_event_3 = create_autospec(ActionEvent, instance=True) - action_event_3.tool_call_id = "call_3" - action_event_3.id = "action_3" - action_event_3.llm_response_id = "response_3" - - # Create different types of observations - # Normal observation - should work - observation_event = create_autospec(ObservationEvent, instance=True) - observation_event.tool_call_id = "call_1" - observation_event.id = "obs_1" - - # User rejection - should work after fix - user_reject_obs = UserRejectObservation( - action_id="action_2", - tool_name="TerminalTool", - tool_call_id="call_2", - rejection_reason="User rejected the action", - ) - - # Agent error - should work after fix (but not before) - agent_error = AgentErrorEvent( - error="Tool execution failed", - tool_name="TerminalTool", - tool_call_id="call_3", - ) - - events = [ - message_event("Start"), - action_event_1, - observation_event, - action_event_2, - user_reject_obs, - action_event_3, - agent_error, - message_event("End"), - ] - - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # After fix: all matched pairs should be kept - # action_event_1 paired with observation_event - # action_event_2 paired with user_reject_obs - # action_event_3 paired with agent_error - - # After the fix, all action events should be kept - # because all observation types are now recognized - assert len(result) == 8 # All events kept - assert action_event_1 in result - assert observation_event in result - assert action_event_2 in result # Fixed! - assert user_reject_obs in result - assert action_event_3 in result # Fixed! - assert agent_error in result diff --git a/tests/sdk/context/test_view_non_exec_filtering.py b/tests/sdk/context/test_view_non_exec_filtering.py deleted file mode 100644 index 928f7d4808..0000000000 --- a/tests/sdk/context/test_view_non_exec_filtering.py +++ /dev/null @@ -1,57 +0,0 @@ -import json - -from openhands.sdk.context.view import View -from openhands.sdk.event.llm_convertible import ( - ActionEvent, - AgentErrorEvent, - MessageEvent, -) -from openhands.sdk.llm import Message, MessageToolCall, TextContent - - -def test_filter_keeps_action_none_when_matched_by_observation() -> None: - """Test that ActionEvent with action=None is kept when matched by observation.""" - # ActionEvent with action=None and a tool_call id - tc = MessageToolCall( - id="call_keep_me", - name="missing_tool", - arguments=json.dumps({}), - origin="completion", - ) - action_event = ActionEvent( - source="agent", - thought=[TextContent(text="...")], - tool_call=tc, - tool_name=tc.name, - tool_call_id=tc.id, - llm_response_id="resp_view_1", - action=None, - ) - - # Matching AgentErrorEvent (observation path) - err = AgentErrorEvent( - source="agent", - error="not found", - tool_name="missing_tool", - tool_call_id="call_keep_me", - ) - - # Noise message events - m1 = MessageEvent( - source="user", - llm_message=Message(role="user", content=[TextContent(text="hi")]), - ) - m2 = MessageEvent( - source="user", - llm_message=Message(role="user", content=[TextContent(text="bye")]), - ) - - events = [m1, action_event, err, m2] - - filtered = View.filter_unmatched_tool_calls(events) # type: ignore[arg-type] - - # Both ActionEvent(action=None) and matching AgentErrorEvent must be kept - assert len(filtered) == 4 - assert action_event in filtered - assert err in filtered - assert m1 in filtered and m2 in filtered diff --git a/tests/sdk/context/view/properties/__init__.py b/tests/sdk/context/view/properties/__init__.py new file mode 100644 index 0000000000..d4652c2af8 --- /dev/null +++ b/tests/sdk/context/view/properties/__init__.py @@ -0,0 +1,5 @@ +"""Tests for View properties. + +This package contains tests for individual property classes that enforce +invariants on the View event sequence. +""" diff --git a/tests/sdk/context/view/properties/test_batch_atomicity.py b/tests/sdk/context/view/properties/test_batch_atomicity.py new file mode 100644 index 0000000000..dc5cb64d47 --- /dev/null +++ b/tests/sdk/context/view/properties/test_batch_atomicity.py @@ -0,0 +1,374 @@ +"""Tests for BatchAtomicityProperty. + +This module tests the BatchAtomicityProperty class independently from the View class. +The property ensures that ActionEvents sharing the same llm_response_id form an atomic +unit that cannot be split during condensation. + +Note: View-level integration tests for batch atomicity also exist in +tests/sdk/context/view/test_view_batch_atomicity.py. These tests will eventually +be removed once we're satisfied with the property-level tests. +""" + +from openhands.sdk.context.view.properties.batch_atomicity import ( + BatchAtomicityProperty, +) +from openhands.sdk.event.llm_convertible import ( + ActionEvent, + MessageEvent, + ObservationEvent, +) +from openhands.sdk.llm import ( + Message, + MessageToolCall, + TextContent, + ThinkingBlock, +) +from openhands.sdk.mcp.definition import MCPToolAction, MCPToolObservation + + +def create_action_event( + llm_response_id: str, + tool_call_id: str, + tool_name: str = "test_tool", + thinking_blocks: list[ThinkingBlock] | None = None, +) -> ActionEvent: + """Helper to create an ActionEvent with specified IDs.""" + action = MCPToolAction(data={}) + + tool_call = MessageToolCall( + id=tool_call_id, + name=tool_name, + arguments="{}", + origin="completion", + ) + + return ActionEvent( + thought=[TextContent(text="Test thought")], + thinking_blocks=thinking_blocks or [], # type: ignore + action=action, + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_call=tool_call, + llm_response_id=llm_response_id, + source="agent", + ) + + +def create_observation_event( + tool_call_id: str, content: str = "Success", tool_name: str = "test_tool" +) -> ObservationEvent: + """Helper to create an ObservationEvent.""" + observation = MCPToolObservation.from_text( + text=content, + tool_name=tool_name, + ) + return ObservationEvent( + observation=observation, + tool_name=tool_name, + tool_call_id=tool_call_id, + action_id="action_event_id", + source="environment", + ) + + +def message_event(content: str) -> MessageEvent: + """Helper to create a MessageEvent.""" + return MessageEvent( + llm_message=Message(role="user", content=[TextContent(text=content)]), + source="user", + ) + + +# ============================================================================ +# Tests for enforce() method +# ============================================================================ + + +def test_enforce_no_removal_when_all_actions_present() -> None: + """Test that no events are removed when all actions in a batch are present.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + action3 = create_action_event("response_1", "call_3") + + current_view = [action1, action2, action3] + all_events = [action1, action2, action3] + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_removes_partial_batch() -> None: + """Test that partial batches are completely removed.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + action3 = create_action_event("response_1", "call_3") + + # All events exist + all_events = [action1, action2, action3] + + # But view only has some of them + current_view = [action1, action3] # Missing action2 + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Should remove all actions from the partial batch + assert action1.id in to_remove + assert action3.id in to_remove + + +def test_enforce_single_action_batch_not_affected() -> None: + """Test that single-action batches are not affected by enforcement.""" + action = create_action_event("response_1", "call_1") + + current_view = [action] + all_events = [action] + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_multiple_batches_only_removes_partial() -> None: + """Test that only partial batches are removed, not complete ones.""" + # Batch 1: complete in view + batch1_action1 = create_action_event("response_1", "call_1") + batch1_action2 = create_action_event("response_1", "call_2") + + # Batch 2: partial in view + batch2_action1 = create_action_event("response_2", "call_3") + batch2_action2 = create_action_event("response_2", "call_4") + + all_events = [batch1_action1, batch1_action2, batch2_action1, batch2_action2] + + # View has all of batch1 but only part of batch2 + current_view = [batch1_action1, batch1_action2, batch2_action1] + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Should only remove batch2_action1 (from the partial batch) + assert batch1_action1.id not in to_remove + assert batch1_action2.id not in to_remove + assert batch2_action1.id in to_remove + + +def test_enforce_with_non_action_events() -> None: + """Test that non-action events don't interfere with batch atomicity.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + msg = message_event("Test message") + + all_events = [action1, action2, msg] + current_view = [action1, msg] # Missing action2 + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Should remove action1 from the partial batch + assert action1.id in to_remove + # Message should not be affected + assert msg.id not in to_remove + + +def test_enforce_empty_view() -> None: + """Test enforce with empty view.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + + all_events = [action1, action2] + current_view = [] + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_with_thinking_blocks() -> None: + """Test that batches with thinking blocks are handled correctly.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Extended thinking", signature="sig1") + ] + + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + action2 = create_action_event("response_1", "call_2") + action3 = create_action_event("response_1", "call_3") + + all_events = [action1, action2, action3] + current_view = [action1, action2] # Missing action3 + + prop = BatchAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Should remove all present actions from partial batch + assert action1.id in to_remove + assert action2.id in to_remove + + +# ============================================================================ +# Tests for manipulation_indices() method +# ============================================================================ + + +def test_manipulation_indices_empty_events() -> None: + """Test manipulation indices with no events.""" + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices([], []) + + # With no events, only index 0 is valid + assert indices == {0} + + +def test_manipulation_indices_single_action() -> None: + """Test manipulation indices with a single action.""" + action = create_action_event("response_1", "call_1") + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices([action], [action]) + + # Single action batch allows manipulation before and after + assert indices == {0, 1} + + +def test_manipulation_indices_multi_action_batch() -> None: + """Test manipulation indices with multi-action batch.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + action3 = create_action_event("response_1", "call_3") + + events = [action1, action2, action3] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Can only manipulate at boundaries: before first action and after last + # Cannot manipulate between action1-action2 or action2-action3 + assert indices == {0, 3} + + +def test_manipulation_indices_interleaved_batch() -> None: + """Test manipulation indices with batch actions interleaved with observations.""" + action1 = create_action_event("response_1", "call_1") + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_1", "call_2") + obs2 = create_observation_event("call_2") + + events = [action1, obs1, action2, obs2] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Batch spans from index 0 (action1) to index 3 (obs2, last observation) + # The batch includes observations, so range extends to the last observation + # Can't manipulate at indices 1, 2, 3 (within batch range w/ observations) + # Can manipulate at 0 (before), 4 (end) + assert indices == {0, 4} + + +def test_manipulation_indices_multiple_batches() -> None: + """Test manipulation indices with multiple separate batches.""" + # Batch 1 + action1_1 = create_action_event("response_1", "call_1") + action1_2 = create_action_event("response_1", "call_2") + + # Batch 2 + action2_1 = create_action_event("response_2", "call_3") + action2_2 = create_action_event("response_2", "call_4") + + events = [action1_1, action1_2, action2_1, action2_2] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Batch 1: indices 0-1, Batch 2: indices 2-3 + # Can manipulate at: 0 (before batch1), 2 (between batches), 4 (after batch2) + assert indices == {0, 2, 4} + + +def test_manipulation_indices_batches_with_messages() -> None: + """Test manipulation indices with messages between batches.""" + msg1 = message_event("Start") + + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + + msg2 = message_event("Middle") + + action3 = create_action_event("response_2", "call_3") + + events = [msg1, action1, action2, msg2, action3] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # msg1 at 0, batch1 at 1-2, msg2 at 3, action3 at 4 + # Can manipulate at: 0, 1, 3, 4, 5 + # Cannot manipulate at: 2 (within batch1) + assert indices == {0, 1, 3, 4, 5} + + +def test_manipulation_indices_non_consecutive_batch() -> None: + """Test manipulation indices when batch actions are non-consecutive.""" + action1 = create_action_event("response_1", "call_1") + msg = message_event("Between") + action2 = create_action_event("response_1", "call_2") + + events = [action1, msg, action2] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Batch spans from index 0 to 2, so can't manipulate at 1 or 2 + assert indices == {0, 3} + + +def test_manipulation_indices_only_single_action_batches() -> None: + """Test that single-action batches without observations don't restrict indices.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_2", "call_2") + action3 = create_action_event("response_3", "call_3") + + events = [action1, action2, action3] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # All single-action batches without observations, so can manipulate anywhere + # Each batch is just a single action with no observation to extend to + assert indices == {0, 1, 2, 3} + + +def test_manipulation_indices_complex_scenario() -> None: + """Test complex scenario with multiple batches and event types.""" + msg1 = message_event("Start") + + # Batch 1: 3 actions + batch1_a1 = create_action_event("response_1", "call_1") + batch1_a2 = create_action_event("response_1", "call_2") + batch1_a3 = create_action_event("response_1", "call_3") + + obs1 = create_observation_event("call_1") + + msg2 = message_event("Middle") + + # Batch 2: 2 actions + batch2_a1 = create_action_event("response_2", "call_4") + batch2_a2 = create_action_event("response_2", "call_5") + + events = [msg1, batch1_a1, batch1_a2, batch1_a3, obs1, msg2, batch2_a1, batch2_a2] + + prop = BatchAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # msg1: 0 + # batch1: 1-4 (includes actions 1-3 and obs1 at index 4) + # msg2: 5 + # batch2: 6-7 (no observations, so just actions) + # end: 8 + # Can manipulate at: 0, 1 (before batch1), 5 (between batches), + # 6 (before batch2), 8 (end) + assert indices == {0, 1, 5, 6, 8} diff --git a/tests/sdk/context/view/properties/test_tool_call_matching.py b/tests/sdk/context/view/properties/test_tool_call_matching.py new file mode 100644 index 0000000000..4d4b4f352d --- /dev/null +++ b/tests/sdk/context/view/properties/test_tool_call_matching.py @@ -0,0 +1,411 @@ +"""Tests for ToolCallMatchingProperty. + +This module tests the ToolCallMatchingProperty class independently from the View class. +The property ensures that ActionEvents and ObservationEvents are properly paired via +tool_call_id. Orphaned actions or observations cause API errors and must be removed. +""" + +from openhands.sdk.context.view.properties.tool_call_matching import ( + ToolCallMatchingProperty, +) +from openhands.sdk.event.llm_convertible import ( + ActionEvent, + AgentErrorEvent, + MessageEvent, + ObservationEvent, + UserRejectObservation, +) +from openhands.sdk.llm import ( + Message, + MessageToolCall, + TextContent, +) +from openhands.sdk.mcp.definition import MCPToolAction, MCPToolObservation + + +def create_action_event( + llm_response_id: str, + tool_call_id: str, + tool_name: str = "test_tool", +) -> ActionEvent: + """Helper to create an ActionEvent with specified IDs.""" + action = MCPToolAction(data={}) + + tool_call = MessageToolCall( + id=tool_call_id, + name=tool_name, + arguments="{}", + origin="completion", + ) + + return ActionEvent( + thought=[TextContent(text="Test thought")], + thinking_blocks=[], + action=action, + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_call=tool_call, + llm_response_id=llm_response_id, + source="agent", + ) + + +def create_observation_event( + tool_call_id: str, content: str = "Success", tool_name: str = "test_tool" +) -> ObservationEvent: + """Helper to create an ObservationEvent.""" + observation = MCPToolObservation.from_text( + text=content, + tool_name=tool_name, + ) + return ObservationEvent( + observation=observation, + tool_name=tool_name, + tool_call_id=tool_call_id, + action_id="action_event_id", + source="environment", + ) + + +def create_user_reject_observation( + tool_call_id: str, tool_name: str = "test_tool" +) -> UserRejectObservation: + """Helper to create a UserRejectObservation.""" + return UserRejectObservation( + tool_name=tool_name, + tool_call_id=tool_call_id, + action_id="action_event_id", + rejection_reason="User rejected", + source="environment", + ) + + +def create_agent_error_event( + tool_call_id: str, tool_name: str = "test_tool" +) -> AgentErrorEvent: + """Helper to create an AgentErrorEvent.""" + return AgentErrorEvent( + tool_name=tool_name, + tool_call_id=tool_call_id, + error="Test error", + source="agent", + ) + + +def message_event(content: str) -> MessageEvent: + """Helper to create a MessageEvent.""" + return MessageEvent( + llm_message=Message(role="user", content=[TextContent(text=content)]), + source="user", + ) + + +# ============================================================================ +# Tests for enforce() method +# ============================================================================ + + +def test_enforce_matched_pairs_no_removal() -> None: + """Test that matched action-observation pairs are not removed.""" + action = create_action_event("response_1", "call_1") + obs = create_observation_event("call_1") + + current_view = [action, obs] + all_events = [action, obs] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_removes_orphaned_action() -> None: + """Test that actions without matching observations are removed.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + obs1 = create_observation_event("call_1") + + current_view = [action1, action2, obs1] + all_events = [action1, action2, obs1] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + # action2 has no matching observation + assert action2.id in to_remove + assert action1.id not in to_remove + assert obs1.id not in to_remove + + +def test_enforce_removes_orphaned_observation() -> None: + """Test that observations without matching actions are removed.""" + action = create_action_event("response_1", "call_1") + obs1 = create_observation_event("call_1") + obs2 = create_observation_event("call_2") # No matching action + + current_view = [action, obs1, obs2] + all_events = [action, obs1, obs2] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + # obs2 has no matching action + assert obs2.id in to_remove + assert action.id not in to_remove + assert obs1.id not in to_remove + + +def test_enforce_removes_both_orphaned_actions_and_observations() -> None: + """Test that both orphaned actions and observations are removed.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") # Orphaned + obs1 = create_observation_event("call_1") + obs2 = create_observation_event("call_3") # Orphaned + + current_view = [action1, action2, obs1, obs2] + all_events = [action1, action2, obs1, obs2] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + assert action2.id in to_remove + assert obs2.id in to_remove + assert action1.id not in to_remove + assert obs1.id not in to_remove + + +def test_enforce_user_reject_observation_counts_as_match() -> None: + """Test that UserRejectObservation matches with ActionEvent.""" + action = create_action_event("response_1", "call_1") + reject = create_user_reject_observation("call_1") + + current_view = [action, reject] + all_events = [action, reject] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + # Both should be kept + assert len(to_remove) == 0 + + +def test_enforce_agent_error_event_counts_as_match() -> None: + """Test that AgentErrorEvent matches with ActionEvent.""" + action = create_action_event("response_1", "call_1") + error = create_agent_error_event("call_1") + + current_view = [action, error] + all_events = [action, error] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + # Both should be kept + assert len(to_remove) == 0 + + +def test_enforce_multiple_observation_types() -> None: + """Test with mix of ObservationEvent, UserRejectObservation, and AgentErrorEvent.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + action3 = create_action_event("response_1", "call_3") + + obs1 = create_observation_event("call_1") + reject2 = create_user_reject_observation("call_2") + error3 = create_agent_error_event("call_3") + + current_view = [action1, action2, action3, obs1, reject2, error3] + all_events = [action1, action2, action3, obs1, reject2, error3] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + # All matched, nothing to remove + assert len(to_remove) == 0 + + +def test_enforce_empty_view() -> None: + """Test enforce with empty view.""" + prop = ToolCallMatchingProperty() + to_remove = prop.enforce([], []) + + assert len(to_remove) == 0 + + +def test_enforce_only_messages() -> None: + """Test that messages are not affected by tool call matching.""" + msg1 = message_event("Message 1") + msg2 = message_event("Message 2") + + current_view = [msg1, msg2] + all_events = [msg1, msg2] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_mixed_with_messages() -> None: + """Test that messages are preserved while orphaned events are removed.""" + msg1 = message_event("Start") + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") # Orphaned + obs1 = create_observation_event("call_1") + msg2 = message_event("End") + + current_view = [msg1, action1, action2, obs1, msg2] + all_events = [msg1, action1, action2, obs1, msg2] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + assert action2.id in to_remove + assert msg1.id not in to_remove + assert msg2.id not in to_remove + + +def test_enforce_cascading_removal() -> None: + """Test that removing actions can cascade to their observations and vice versa. + + Note: This property doesn't do cascading - each element is independently checked. + Cascading would require multiple passes or composition with other properties. + """ + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + obs1 = create_observation_event("call_1") + obs2 = create_observation_event("call_2") + + # View has action1 and obs2, but missing their pairs + current_view = [action1, obs2] + all_events = [action1, action2, obs1, obs2] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + # Both should be removed as orphans + assert action1.id in to_remove # No obs1 in view + assert obs2.id in to_remove # No action2 in view + + +def test_enforce_same_tool_call_id_different_events() -> None: + """Test that matching works even with same tool_call_id on different events.""" + action = create_action_event("response_1", "call_1") + obs = create_observation_event("call_1") + + current_view = [action, obs] + all_events = [action, obs] + + prop = ToolCallMatchingProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +# ============================================================================ +# Tests for manipulation_indices() method +# ============================================================================ + + +def test_manipulation_indices_all_valid() -> None: + """Test that all indices are valid for tool call matching property. + + Unlike batch atomicity and tool loop atomicity, this property doesn't + restrict manipulation indices. It validates through filtering instead. + """ + action = create_action_event("response_1", "call_1") + obs = create_observation_event("call_1") + + events = [action, obs] + + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices(events, events) + + # All indices should be valid + assert indices == {0, 1, 2} + + +def test_manipulation_indices_empty_events() -> None: + """Test with empty event list.""" + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices([], []) + + assert indices == {0} + + +def test_manipulation_indices_complex_scenario() -> None: + """Test that all indices are valid regardless of event complexity.""" + msg1 = message_event("Start") + + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + obs1 = create_observation_event("call_1") + obs2 = create_observation_event("call_2") + + msg2 = message_event("End") + + events = [msg1, action1, action2, obs1, obs2, msg2] + + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices(events, events) + + # All indices are valid + assert indices == {0, 1, 2, 3, 4, 5, 6} + + +def test_manipulation_indices_orphaned_events() -> None: + """Test that orphaned events don't affect manipulation indices.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") # Orphaned + obs1 = create_observation_event("call_1") + + events = [action1, action2, obs1] + + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices(events, events) + + # All indices are still valid + assert indices == {0, 1, 2, 3} + + +def test_manipulation_indices_only_messages() -> None: + """Test with only message events.""" + msg1 = message_event("Message 1") + msg2 = message_event("Message 2") + msg3 = message_event("Message 3") + + events = [msg1, msg2, msg3] + + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices(events, events) + + assert indices == {0, 1, 2, 3} + + +def test_manipulation_indices_with_different_observation_types() -> None: + """Test that different observation types don't affect indices.""" + action1 = create_action_event("response_1", "call_1") + action2 = create_action_event("response_1", "call_2") + action3 = create_action_event("response_1", "call_3") + + obs = create_observation_event("call_1") + reject = create_user_reject_observation("call_2") + error = create_agent_error_event("call_3") + + events = [action1, action2, action3, obs, reject, error] + + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices(events, events) + + # All indices valid + assert indices == {0, 1, 2, 3, 4, 5, 6} + + +def test_manipulation_indices_single_event() -> None: + """Test with a single event.""" + action = create_action_event("response_1", "call_1") + + prop = ToolCallMatchingProperty() + indices = prop.manipulation_indices([action], [action]) + + assert indices == {0, 1} diff --git a/tests/sdk/context/view/properties/test_tool_loop_atomicity.py b/tests/sdk/context/view/properties/test_tool_loop_atomicity.py new file mode 100644 index 0000000000..6e197a5fba --- /dev/null +++ b/tests/sdk/context/view/properties/test_tool_loop_atomicity.py @@ -0,0 +1,488 @@ +"""Tests for ToolLoopAtomicityProperty. + +This module tests the ToolLoopAtomicityProperty class independently from the View class. +The property ensures that tool loops (thinking blocks + consecutive tool calls) remain +atomic units that cannot be split. + +A tool loop consists of: +- An initial batch containing thinking blocks +- All subsequent consecutive ActionEvent/ObservationEvent batches +- Terminated by the first non-ActionEvent/ObservationEvent +""" + +from collections.abc import Sequence + +from openhands.sdk.context.view.properties.tool_loop_atomicity import ( + ToolLoopAtomicityProperty, +) +from openhands.sdk.event.llm_convertible import ( + ActionEvent, + MessageEvent, + ObservationEvent, +) +from openhands.sdk.llm import ( + Message, + MessageToolCall, + RedactedThinkingBlock, + TextContent, + ThinkingBlock, +) +from openhands.sdk.mcp.definition import MCPToolAction, MCPToolObservation + + +def create_action_event( + llm_response_id: str, + tool_call_id: str, + tool_name: str = "test_tool", + thinking_blocks: Sequence[ThinkingBlock | RedactedThinkingBlock] | None = None, +) -> ActionEvent: + """Helper to create an ActionEvent with specified IDs.""" + action = MCPToolAction(data={}) + + tool_call = MessageToolCall( + id=tool_call_id, + name=tool_name, + arguments="{}", + origin="completion", + ) + + return ActionEvent( + thought=[TextContent(text="Test thought")], + thinking_blocks=list(thinking_blocks) if thinking_blocks else [], + action=action, + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_call=tool_call, + llm_response_id=llm_response_id, + source="agent", + ) + + +def create_observation_event( + tool_call_id: str, content: str = "Success", tool_name: str = "test_tool" +) -> ObservationEvent: + """Helper to create an ObservationEvent.""" + observation = MCPToolObservation.from_text( + text=content, + tool_name=tool_name, + ) + return ObservationEvent( + observation=observation, + tool_name=tool_name, + tool_call_id=tool_call_id, + action_id="action_event_id", + source="environment", + ) + + +def message_event(content: str) -> MessageEvent: + """Helper to create a MessageEvent.""" + return MessageEvent( + llm_message=Message(role="user", content=[TextContent(text=content)]), + source="user", + ) + + +# ============================================================================ +# Tests for enforce() method +# ============================================================================ + + +def test_enforce_complete_tool_loop_no_removal() -> None: + """Test that complete tool loops are not removed.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + all_events = [action1, obs1, action2, obs2] + current_view = [action1, obs1, action2, obs2] + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_partial_tool_loop_removed() -> None: + """Test that partial tool loops are completely removed.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + # Complete tool loop in all_events + all_events = [action1, obs1, action2, obs2] + + # But view is missing action2 (partial loop) + current_view = [action1, obs1, obs2] + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Should remove all events from the partial tool loop + assert action1.id in to_remove + assert obs1.id in to_remove + assert obs2.id in to_remove + + +def test_enforce_no_thinking_blocks_no_enforcement() -> None: + """Test that actions without thinking blocks don't trigger tool loop enforcement.""" + action1 = create_action_event("response_1", "call_1") + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + all_events = [action1, obs1, action2, obs2] + current_view = [action1, obs1] # Partial view, but no thinking blocks + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # No tool loop, so no enforcement + assert len(to_remove) == 0 + + +def test_enforce_tool_loop_terminated_by_message() -> None: + """Test that tool loops are correctly terminated by non-action/observation.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs1 = create_observation_event("call_1") + msg = message_event("User message") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + all_events = [action1, obs1, msg, action2, obs2] + + # View has first part of loop but not msg + current_view = [action1, obs1, action2, obs2] + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Tool loop is just action1+obs1 (terminated by msg) + # action2/obs2 are separate, so if we're missing msg, the loop is still complete + # Actually, looking at the all_events, the loop is action1+obs1, terminated by msg + # current_view has action1+obs1 which is the complete loop, so nothing to remove + assert len(to_remove) == 0 + + +def test_enforce_multiple_batches_in_tool_loop() -> None: + """Test tool loop spanning multiple batches.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + # First batch with thinking + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + action2 = create_action_event("response_1", "call_2") + obs1 = create_observation_event("call_1") + obs2 = create_observation_event("call_2") + + # Second batch (extends the loop) + action3 = create_action_event("response_2", "call_3") + obs3 = create_observation_event("call_3") + + all_events = [action1, action2, obs1, obs2, action3, obs3] + + # View is missing obs3 (partial loop) + current_view = [action1, action2, obs1, obs2, action3] + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # Should remove all events from the partial tool loop + assert action1.id in to_remove + assert action2.id in to_remove + assert obs1.id in to_remove + assert obs2.id in to_remove + assert action3.id in to_remove + + +def test_enforce_empty_view() -> None: + """Test enforce with empty view.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + action = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs = create_observation_event("call_1") + + all_events = [action, obs] + current_view = [] + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + assert len(to_remove) == 0 + + +def test_enforce_redacted_thinking_blocks() -> None: + """Test that redacted thinking blocks also trigger tool loop logic.""" + thinking = [RedactedThinkingBlock(type="redacted_thinking", data="redacted")] + + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + all_events = [action1, obs1, action2, obs2] + current_view = [action1, obs1] # Missing the continuation + + prop = ToolLoopAtomicityProperty() + to_remove = prop.enforce(current_view, all_events) + + # action1 has thinking, action2 is consecutive (all action/obs between) + # So the full loop is action1, obs1, action2, obs2 + # current_view only has action1, obs1, so it's partial + assert len(to_remove) == 2 + + +# ============================================================================ +# Tests for manipulation_indices() method +# ============================================================================ + + +def test_manipulation_indices_no_thinking_blocks() -> None: + """Test that without thinking blocks, all indices are valid.""" + action1 = create_action_event("response_1", "call_1") + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + events = [action1, obs1, action2, obs2] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # No tool loops, so all indices are valid + assert indices == {0, 1, 2, 3, 4} + + +def test_manipulation_indices_simple_tool_loop() -> None: + """Test manipulation indices with a simple tool loop.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + action = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs = create_observation_event("call_1") + + events = [action, obs] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Tool loop spans indices 0-1, can only manipulate at boundaries + assert indices == {0, 2} + + +def test_manipulation_indices_tool_loop_with_continuation() -> None: + """Test manipulation indices when tool loop continues across batches.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + # Batch 1 with thinking + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs1 = create_observation_event("call_1") + + # Batch 2 (consecutive, extends the loop) + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + events = [action1, obs1, action2, obs2] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Entire sequence is one tool loop + assert indices == {0, 4} + + +def test_manipulation_indices_tool_loop_terminated_by_message() -> None: + """Test that messages terminate tool loops.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + obs1 = create_observation_event("call_1") + msg = message_event("User message") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + events = [action1, obs1, msg, action2, obs2] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Tool loop is action1+obs1 (indices 0-1), terminated by msg at 2 + # action2+obs2 form a separate batch (no thinking blocks) + # Can manipulate at: 0, 2, 3, 4, 5 + # Cannot manipulate at: 1 (within tool loop) + assert indices == {0, 2, 3, 4, 5} + + +def test_manipulation_indices_multiple_tool_loops() -> None: + """Test multiple separate tool loops.""" + thinking1 = [ + ThinkingBlock(type="thinking", thinking="Thinking 1", signature="sig1") + ] + thinking2 = [ + ThinkingBlock(type="thinking", thinking="Thinking 2", signature="sig2") + ] + + # First tool loop + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking1) + obs1 = create_observation_event("call_1") + + msg = message_event("Between loops") + + # Second tool loop + action2 = create_action_event("response_2", "call_2", thinking_blocks=thinking2) + obs2 = create_observation_event("call_2") + + events = [action1, obs1, msg, action2, obs2] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Loop1: 0-1, msg: 2, Loop2: 3-4 + # Can manipulate at: 0, 2, 3, 5 + # Cannot manipulate at: 1 (within loop1), 4 (within loop2) + assert indices == {0, 2, 3, 5} + + +def test_manipulation_indices_multi_batch_tool_loop() -> None: + """Test tool loop spanning multiple action batches.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + # Batch 1 with thinking (2 actions) + action1_1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + action1_2 = create_action_event("response_1", "call_2") + obs1_1 = create_observation_event("call_1") + obs1_2 = create_observation_event("call_2") + + # Batch 2 (consecutive, extends loop) + action2 = create_action_event("response_2", "call_3") + obs2 = create_observation_event("call_3") + + # Batch 3 (consecutive, extends loop) + action3 = create_action_event("response_3", "call_4") + obs3 = create_observation_event("call_4") + + msg = message_event("End") + + events = [ + action1_1, + action1_2, + obs1_1, + obs1_2, + action2, + obs2, + action3, + obs3, + msg, + ] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # Entire tool loop spans 0-7, terminated by msg at 8 + # Can manipulate at: 0, 8, 9 + assert indices == {0, 8, 9} + + +def test_manipulation_indices_empty_events() -> None: + """Test with empty event list.""" + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices([], []) + + assert indices == {0} + + +def test_manipulation_indices_single_message() -> None: + """Test with single message event.""" + msg = message_event("Test") + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices([msg], [msg]) + + # No tool loops, all indices valid + assert indices == {0, 1} + + +def test_manipulation_indices_interleaved_observations() -> None: + """Test tool loop with observations interleaved between actions.""" + thinking = [ + ThinkingBlock(type="thinking", thinking="Thinking...", signature="sig1") + ] + + # Batch with thinking + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking) + action2 = create_action_event("response_1", "call_2") + + obs1 = create_observation_event("call_1") + + # Another batch (consecutive) + action3 = create_action_event("response_2", "call_3") + + obs2 = create_observation_event("call_2") + obs3 = create_observation_event("call_3") + + events = [action1, action2, obs1, action3, obs2, obs3] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # All are actions/observations following thinking batch, so one big loop + assert indices == {0, 6} + + +def test_manipulation_indices_complex_scenario() -> None: + """Test complex scenario with multiple loops and event types.""" + thinking1 = [ThinkingBlock(type="thinking", thinking="First", signature="sig1")] + + msg1 = message_event("Start") + + # Tool loop 1 + action1 = create_action_event("response_1", "call_1", thinking_blocks=thinking1) + obs1 = create_observation_event("call_1") + action2 = create_action_event("response_2", "call_2") + obs2 = create_observation_event("call_2") + + msg2 = message_event("Middle") + + # Regular action without thinking + action3 = create_action_event("response_3", "call_3") + obs3 = create_observation_event("call_3") + + msg3 = message_event("End") + + events = [msg1, action1, obs1, action2, obs2, msg2, action3, obs3, msg3] + + prop = ToolLoopAtomicityProperty() + indices = prop.manipulation_indices(events, events) + + # msg1: 0 + # Loop1: 1-4 (action1, obs1, action2, obs2) + # msg2: 5 + # action3/obs3: 6-7 (no thinking, not a loop) + # msg3: 8 + # Can manipulate at: 0, 1, 5, 6, 7, 8, 9 + # Cannot manipulate at: 2, 3, 4 (within loop1) + assert indices == {0, 1, 5, 6, 7, 8, 9} diff --git a/tests/sdk/context/view/properties/test_view_property_base.py b/tests/sdk/context/view/properties/test_view_property_base.py new file mode 100644 index 0000000000..6c679d9f4b --- /dev/null +++ b/tests/sdk/context/view/properties/test_view_property_base.py @@ -0,0 +1,267 @@ +"""Tests for ViewPropertyBase utility functions.""" + +from openhands.sdk.context.view.properties.base import ViewPropertyBase +from openhands.sdk.event.llm_convertible import ( + ActionEvent, + MessageEvent, + ObservationEvent, +) +from openhands.sdk.llm import Message, MessageToolCall, TextContent +from openhands.sdk.mcp.definition import MCPToolAction, MCPToolObservation + + +def create_action_event( + llm_response_id: str, + tool_call_id: str, + tool_name: str = "test_tool", +) -> ActionEvent: + """Helper to create an ActionEvent.""" + action = MCPToolAction(data={}) + tool_call = MessageToolCall( + id=tool_call_id, + name=tool_name, + arguments="{}", + origin="completion", + ) + + return ActionEvent( + thought=[TextContent(text="Test thought")], + thinking_blocks=[], + action=action, + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_call=tool_call, + llm_response_id=llm_response_id, + source="agent", + ) + + +def create_observation_event( + tool_call_id: str, content: str = "Success", tool_name: str = "test_tool" +) -> ObservationEvent: + """Helper to create an ObservationEvent.""" + observation = MCPToolObservation.from_text( + text=content, + tool_name=tool_name, + ) + return ObservationEvent( + observation=observation, + tool_name=tool_name, + tool_call_id=tool_call_id, + action_id="action_event_id", + source="environment", + ) + + +def message_event(content: str) -> MessageEvent: + """Helper to create a MessageEvent.""" + return MessageEvent( + llm_message=Message(role="user", content=[TextContent(text=content)]), + source="user", + ) + + +# ============================================================================ +# Tests for _build_batches() utility function +# ============================================================================ + + +def test_build_batches_empty_list() -> None: + """Test _build_batches with empty event list.""" + result = ViewPropertyBase._build_batches([]) + assert result == {} + + +def test_build_batches_no_action_events() -> None: + """Test _build_batches with no ActionEvents.""" + events = [ + message_event("Hello"), + create_observation_event("call_1"), + ] + result = ViewPropertyBase._build_batches(events) + assert result == {} + + +def test_build_batches_single_action() -> None: + """Test _build_batches with single ActionEvent.""" + action = create_action_event("resp_1", "call_1") + events = [action] + + result = ViewPropertyBase._build_batches(events) + + assert len(result) == 1 + assert "resp_1" in result + assert result["resp_1"] == [action.id] + + +def test_build_batches_multiple_actions_same_response() -> None: + """Test _build_batches with multiple actions from same LLM response.""" + action1 = create_action_event("resp_1", "call_1") + action2 = create_action_event("resp_1", "call_2") + action3 = create_action_event("resp_1", "call_3") + events = [action1, action2, action3] + + result = ViewPropertyBase._build_batches(events) + + assert len(result) == 1 + assert "resp_1" in result + assert result["resp_1"] == [action1.id, action2.id, action3.id] + + +def test_build_batches_multiple_actions_different_responses() -> None: + """Test _build_batches with actions from different LLM responses.""" + action1 = create_action_event("resp_1", "call_1") + action2 = create_action_event("resp_2", "call_2") + action3 = create_action_event("resp_3", "call_3") + events = [action1, action2, action3] + + result = ViewPropertyBase._build_batches(events) + + assert len(result) == 3 + assert result["resp_1"] == [action1.id] + assert result["resp_2"] == [action2.id] + assert result["resp_3"] == [action3.id] + + +def test_build_batches_mixed_event_types() -> None: + """Test _build_batches with mixed event types.""" + msg = message_event("User message") + action1 = create_action_event("resp_1", "call_1") + obs1 = create_observation_event("call_1") + action2 = create_action_event("resp_2", "call_2") + obs2 = create_observation_event("call_2") + + events = [msg, action1, obs1, action2, obs2] + + result = ViewPropertyBase._build_batches(events) + + assert len(result) == 2 + assert result["resp_1"] == [action1.id] + assert result["resp_2"] == [action2.id] + + +def test_build_batches_parallel_calls() -> None: + """Test _build_batches with parallel tool calls (same llm_response_id).""" + action1 = create_action_event("resp_1", "call_1a") + action2 = create_action_event("resp_1", "call_1b") + action3 = create_action_event("resp_1", "call_1c") + obs1 = create_observation_event("call_1a") + obs2 = create_observation_event("call_1b") + obs3 = create_observation_event("call_1c") + + events = [action1, action2, action3, obs1, obs2, obs3] + + result = ViewPropertyBase._build_batches(events) + + assert len(result) == 1 + assert result["resp_1"] == [action1.id, action2.id, action3.id] + + +def test_build_batches_interleaved_batches() -> None: + """Test _build_batches with interleaved batches.""" + action1a = create_action_event("resp_1", "call_1a") + action2a = create_action_event("resp_2", "call_2a") + action1b = create_action_event("resp_1", "call_1b") + action2b = create_action_event("resp_2", "call_2b") + + events = [action1a, action2a, action1b, action2b] + + result = ViewPropertyBase._build_batches(events) + + assert len(result) == 2 + assert result["resp_1"] == [action1a.id, action1b.id] + assert result["resp_2"] == [action2a.id, action2b.id] + + +# ============================================================================ +# Tests for _build_event_id_to_index() utility function +# ============================================================================ + + +def test_build_event_id_to_index_empty_list() -> None: + """Test _build_event_id_to_index with empty event list.""" + result = ViewPropertyBase._build_event_id_to_index([]) + assert result == {} + + +def test_build_event_id_to_index_single_event() -> None: + """Test _build_event_id_to_index with single event.""" + event = message_event("Hello") + events = [event] + + result = ViewPropertyBase._build_event_id_to_index(events) + + assert len(result) == 1 + assert result[event.id] == 0 + + +def test_build_event_id_to_index_multiple_events() -> None: + """Test _build_event_id_to_index with multiple events.""" + event1 = message_event("Hello") + event2 = create_action_event("resp_1", "call_1") + event3 = create_observation_event("call_1") + event4 = message_event("Goodbye") + + events = [event1, event2, event3, event4] + + result = ViewPropertyBase._build_event_id_to_index(events) + + assert len(result) == 4 + assert result[event1.id] == 0 + assert result[event2.id] == 1 + assert result[event3.id] == 2 + assert result[event4.id] == 3 + + +def test_build_event_id_to_index_preserves_order() -> None: + """Test that _build_event_id_to_index preserves event order.""" + events = [create_action_event(f"resp_{i}", f"call_{i}") for i in range(10)] + + result = ViewPropertyBase._build_event_id_to_index(events) + + assert len(result) == 10 + for idx, event in enumerate(events): + assert result[event.id] == idx + + +def test_build_event_id_to_index_different_event_types() -> None: + """Test _build_event_id_to_index with different event types.""" + msg1 = message_event("User 1") + action1 = create_action_event("resp_1", "call_1") + action2 = create_action_event("resp_1", "call_2") # Parallel call + obs1 = create_observation_event("call_1") + obs2 = create_observation_event("call_2") + msg2 = message_event("User 2") + + events = [msg1, action1, action2, obs1, obs2, msg2] + + result = ViewPropertyBase._build_event_id_to_index(events) + + assert len(result) == 6 + assert result[msg1.id] == 0 + assert result[action1.id] == 1 + assert result[action2.id] == 2 + assert result[obs1.id] == 3 + assert result[obs2.id] == 4 + assert result[msg2.id] == 5 + + +def test_build_event_id_to_index_unique_ids() -> None: + """Test that each event has a unique ID and index.""" + events = [ + message_event("Message 1"), + create_action_event("resp_1", "call_1"), + create_observation_event("call_1"), + message_event("Message 2"), + ] + + result = ViewPropertyBase._build_event_id_to_index(events) + + # All event IDs should be unique + assert len(result) == len(events) + + # All indices should be unique and in range [0, len(events)) + indices = list(result.values()) + assert len(set(indices)) == len(indices) + assert min(indices) == 0 + assert max(indices) == len(events) - 1 diff --git a/tests/sdk/context/test_view.py b/tests/sdk/context/view/test_view.py similarity index 60% rename from tests/sdk/context/test_view.py rename to tests/sdk/context/view/test_view.py index 6ab90ca97a..5353155959 100644 --- a/tests/sdk/context/test_view.py +++ b/tests/sdk/context/view/test_view.py @@ -1,5 +1,4 @@ from typing import cast -from unittest.mock import create_autospec from openhands.sdk.context.view import View from openhands.sdk.event.base import Event @@ -9,9 +8,7 @@ CondensationSummaryEvent, ) from openhands.sdk.event.llm_convertible import ( - ActionEvent, MessageEvent, - ObservationEvent, ) from openhands.sdk.llm import Message, TextContent @@ -593,379 +590,3 @@ def test_summary_event_with_zero_offset() -> None: assert view.summary_event is not None assert view.summary_event.summary == "Summary at beginning" assert view[0] == view.summary_event # Summary is first event - - -# Tests for unmatched tool call filtering functionality moved from CondenserBase - - -def test_filter_unmatched_tool_calls_empty_list() -> None: - """Test filter_unmatched_tool_calls with empty event list.""" - result = View.filter_unmatched_tool_calls([]) - assert result == [] - - -def test_filter_unmatched_tool_calls_no_tool_events() -> None: - """Test filter_unmatched_tool_calls with no tool events.""" - # Create mock non-tool events - message_event_1 = create_autospec(MessageEvent, instance=True) - message_event_1.id = "msg_1" - message_event_2 = create_autospec(MessageEvent, instance=True) - message_event_2.id = "msg_2" - - events = [message_event_1, message_event_2] - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # All non-tool events should be kept - assert len(result) == 2 - assert message_event_1 in result - assert message_event_2 in result - - -def test_filter_unmatched_tool_calls_matched_pairs() -> None: - """Test filter_unmatched_tool_calls with matched tool call pairs.""" - # Create mock events - message_event = create_autospec(MessageEvent, instance=True) - message_event.id = "msg_1" - - # Matched pair 1 - action_event_1 = create_autospec(ActionEvent, instance=True) - action_event_1.tool_call_id = "call_1" - action_event_1.id = "action_1" - action_event_1.llm_response_id = "response_1" - - observation_event_1 = create_autospec(ObservationEvent, instance=True) - observation_event_1.tool_call_id = "call_1" - observation_event_1.id = "obs_1" - - # Matched pair 2 - action_event_2 = create_autospec(ActionEvent, instance=True) - action_event_2.tool_call_id = "call_2" - action_event_2.id = "action_2" - action_event_2.llm_response_id = "response_2" - - observation_event_2 = create_autospec(ObservationEvent, instance=True) - observation_event_2.tool_call_id = "call_2" - observation_event_2.id = "obs_2" - - events = [ - message_event, - action_event_1, - observation_event_1, - action_event_2, - observation_event_2, - ] - - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # All events should be kept (all tool calls are matched) - assert len(result) == 5 - assert message_event in result - assert action_event_1 in result - assert observation_event_1 in result - assert action_event_2 in result - assert observation_event_2 in result - - -def test_filter_unmatched_tool_calls_unmatched_action() -> None: - """Test filter_unmatched_tool_calls with unmatched ActionEvent.""" - # Create mock events - message_event = create_autospec(MessageEvent, instance=True) - message_event.id = "msg_1" - - # Matched pair - action_event_matched = create_autospec(ActionEvent, instance=True) - action_event_matched.tool_call_id = "call_1" - action_event_matched.id = "action_1" - action_event_matched.llm_response_id = "response_1" - - observation_event_matched = create_autospec(ObservationEvent, instance=True) - observation_event_matched.tool_call_id = "call_1" - observation_event_matched.id = "obs_1" - - # Unmatched ActionEvent - action_event_unmatched = create_autospec(ActionEvent, instance=True) - action_event_unmatched.tool_call_id = "call_2" - action_event_unmatched.id = "action_2" - action_event_unmatched.llm_response_id = "response_2" - - events = [ - message_event, - action_event_matched, - observation_event_matched, - action_event_unmatched, - ] - - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # Should keep: message_event, matched pair - # Should filter out: unmatched ActionEvent - assert len(result) == 3 - assert message_event in result - assert action_event_matched in result - assert observation_event_matched in result - assert action_event_unmatched not in result - - -def test_filter_unmatched_tool_calls_unmatched_observation() -> None: - """Test filter_unmatched_tool_calls with unmatched ObservationEvent.""" - # Create mock events - message_event = create_autospec(MessageEvent, instance=True) - message_event.id = "msg_1" - - # Matched pair - action_event_matched = create_autospec(ActionEvent, instance=True) - action_event_matched.tool_call_id = "call_1" - action_event_matched.id = "action_1" - action_event_matched.llm_response_id = "response_1" - - observation_event_matched = create_autospec(ObservationEvent, instance=True) - observation_event_matched.tool_call_id = "call_1" - observation_event_matched.id = "obs_1" - - # Unmatched ObservationEvent - observation_event_unmatched = create_autospec(ObservationEvent, instance=True) - observation_event_unmatched.tool_call_id = "call_2" - observation_event_unmatched.id = "obs_2" - - events = [ - message_event, - action_event_matched, - observation_event_matched, - observation_event_unmatched, - ] - - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # Should keep: message_event, matched pair - # Should filter out: unmatched ObservationEvent - assert len(result) == 3 - assert message_event in result - assert action_event_matched in result - assert observation_event_matched in result - assert observation_event_unmatched not in result - - -def test_filter_unmatched_tool_calls_mixed_scenario() -> None: - """Test filter_unmatched_tool_calls with complex mixed scenario.""" - # Create mock events - message_event_1 = create_autospec(MessageEvent, instance=True) - message_event_1.id = "msg_1" - message_event_2 = create_autospec(MessageEvent, instance=True) - message_event_2.id = "msg_2" - - # Matched pair 1 - action_event_1 = create_autospec(ActionEvent, instance=True) - action_event_1.tool_call_id = "call_1" - action_event_1.id = "action_1" - action_event_1.llm_response_id = "response_1" - - observation_event_1 = create_autospec(ObservationEvent, instance=True) - observation_event_1.tool_call_id = "call_1" - observation_event_1.id = "obs_1" - - # Unmatched ActionEvent - action_event_unmatched = create_autospec(ActionEvent, instance=True) - action_event_unmatched.tool_call_id = "call_2" - action_event_unmatched.id = "action_unmatched" - action_event_unmatched.llm_response_id = "response_2" - - # Unmatched ObservationEvent - observation_event_unmatched = create_autospec(ObservationEvent, instance=True) - observation_event_unmatched.tool_call_id = "call_3" - observation_event_unmatched.id = "obs_unmatched" - - # Matched pair 2 - action_event_2 = create_autospec(ActionEvent, instance=True) - action_event_2.tool_call_id = "call_4" - action_event_2.id = "action_2" - action_event_2.llm_response_id = "response_3" - - observation_event_2 = create_autospec(ObservationEvent, instance=True) - observation_event_2.tool_call_id = "call_4" - observation_event_2.id = "obs_2" - - events = [ - message_event_1, - action_event_1, - observation_event_1, - action_event_unmatched, - observation_event_unmatched, - message_event_2, - action_event_2, - observation_event_2, - ] - - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # Should keep: message events and matched pairs - # Should filter out: unmatched action and observation events - assert len(result) == 6 - assert message_event_1 in result - assert message_event_2 in result - assert action_event_1 in result - assert observation_event_1 in result - assert action_event_2 in result - assert observation_event_2 in result - assert action_event_unmatched not in result - assert observation_event_unmatched not in result - - -def test_filter_unmatched_tool_calls_none_tool_call_id() -> None: - """Test filter_unmatched_tool_calls with None tool_call_id.""" - # Create mock events with None tool_call_id - action_event_none = create_autospec(ActionEvent, instance=True) - action_event_none.tool_call_id = None - action_event_none.id = "action_none" - action_event_none.llm_response_id = "response_1" - - observation_event_none = create_autospec(ObservationEvent, instance=True) - observation_event_none.tool_call_id = None - observation_event_none.id = "obs_none" - - # Valid matched pair - action_event_valid = create_autospec(ActionEvent, instance=True) - action_event_valid.tool_call_id = "call_1" - action_event_valid.id = "action_valid" - action_event_valid.llm_response_id = "response_2" - - observation_event_valid = create_autospec(ObservationEvent, instance=True) - observation_event_valid.tool_call_id = "call_1" - observation_event_valid.id = "obs_valid" - - events = [ - action_event_none, - observation_event_none, - action_event_valid, - observation_event_valid, - ] - - result = View.filter_unmatched_tool_calls(events) # type: ignore - - # Should keep only the valid matched pair - # Events with None tool_call_id should be filtered out - assert len(result) == 2 - assert action_event_valid in result - assert observation_event_valid in result - assert action_event_none not in result - assert observation_event_none not in result - - -def test_get_action_tool_call_ids() -> None: - """Test _get_action_tool_call_ids helper method.""" - # Create mock events - message_event = create_autospec(MessageEvent, instance=True) - - action_event_1 = create_autospec(ActionEvent, instance=True) - action_event_1.tool_call_id = "call_1" - - action_event_2 = create_autospec(ActionEvent, instance=True) - action_event_2.tool_call_id = "call_2" - - action_event_none = create_autospec(ActionEvent, instance=True) - action_event_none.tool_call_id = None - - observation_event = create_autospec(ObservationEvent, instance=True) - observation_event.tool_call_id = "call_3" - - events = [ - message_event, - action_event_1, - action_event_2, - action_event_none, - observation_event, - ] - - result = View._get_action_tool_call_ids(events) # type: ignore - - # Should only include tool_call_ids from ActionEvents with non-None tool_call_id - assert result == {"call_1", "call_2"} - - -def test_get_observation_tool_call_ids() -> None: - """Test _get_observation_tool_call_ids helper method.""" - # Create mock events - message_event = create_autospec(MessageEvent, instance=True) - - observation_event_1 = create_autospec(ObservationEvent, instance=True) - observation_event_1.tool_call_id = "call_1" - - observation_event_2 = create_autospec(ObservationEvent, instance=True) - observation_event_2.tool_call_id = "call_2" - - observation_event_none = create_autospec(ObservationEvent, instance=True) - observation_event_none.tool_call_id = None - - action_event = create_autospec(ActionEvent, instance=True) - action_event.tool_call_id = "call_3" - - events = [ - message_event, - observation_event_1, - observation_event_2, - observation_event_none, - action_event, - ] - - result = View._get_observation_tool_call_ids(events) # type: ignore - - # Should only include tool_call_ids from ObservationEvents with non-None - # tool_call_id - assert result == {"call_1", "call_2"} - - -def test_should_keep_event_observation_event() -> None: - """Test _should_keep_event with ObservationEvent.""" - observation_event = create_autospec(ObservationEvent, instance=True) - observation_event.tool_call_id = "call_1" - - action_tool_call_ids = {"call_1", "call_2"} - observation_tool_call_ids = {"call_1", "call_3"} - - # Should keep because tool_call_id is in action_tool_call_ids - result = View._should_keep_event( - observation_event, action_tool_call_ids, observation_tool_call_ids - ) - assert result is True - - # Should not keep because tool_call_id is not in action_tool_call_ids - action_tool_call_ids_no_match = {"call_2", "call_3"} - result = View._should_keep_event( - observation_event, action_tool_call_ids_no_match, observation_tool_call_ids - ) - assert result is False - - -def test_should_keep_event_action_event() -> None: - """Test _should_keep_event with ActionEvent.""" - action_event = create_autospec(ActionEvent, instance=True) - action_event.tool_call_id = "call_1" - - action_tool_call_ids = {"call_1", "call_2"} - observation_tool_call_ids = {"call_1", "call_3"} - - # Should keep because tool_call_id is in observation_tool_call_ids - result = View._should_keep_event( - action_event, action_tool_call_ids, observation_tool_call_ids - ) - assert result is True - - # Should not keep because tool_call_id is not in observation_tool_call_ids - observation_tool_call_ids_no_match = {"call_2", "call_3"} - result = View._should_keep_event( - action_event, action_tool_call_ids, observation_tool_call_ids_no_match - ) - assert result is False - - -def test_should_keep_event_other_event_types() -> None: - """Test _should_keep_event with non-tool event types.""" - message_event = create_autospec(MessageEvent, instance=True) - - action_tool_call_ids = {"call_1"} - observation_tool_call_ids = {"call_2"} - - # Should always keep non-tool events - result = View._should_keep_event( - message_event, action_tool_call_ids, observation_tool_call_ids - ) - assert result is True diff --git a/tests/sdk/context/test_view_batch_atomicity.py b/tests/sdk/context/view/test_view_batch_atomicity.py similarity index 100% rename from tests/sdk/context/test_view_batch_atomicity.py rename to tests/sdk/context/view/test_view_batch_atomicity.py diff --git a/tests/sdk/context/test_view_condensation_batch_atomicity.py b/tests/sdk/context/view/test_view_condensation_batch_atomicity.py similarity index 100% rename from tests/sdk/context/test_view_condensation_batch_atomicity.py rename to tests/sdk/context/view/test_view_condensation_batch_atomicity.py diff --git a/tests/sdk/context/test_view_manipulation_indices.py b/tests/sdk/context/view/test_view_manipulation_indices.py similarity index 95% rename from tests/sdk/context/test_view_manipulation_indices.py rename to tests/sdk/context/view/test_view_manipulation_indices.py index 9c8b2789fc..59c8341f24 100644 --- a/tests/sdk/context/test_view_manipulation_indices.py +++ b/tests/sdk/context/view/test_view_manipulation_indices.py @@ -76,7 +76,7 @@ def message_event(content: str) -> MessageEvent: def test_empty_list() -> None: """Test manipulation_indices with empty event list.""" view = View.from_events([]) - assert view.manipulation_indices == [0] + assert view.manipulation_indices == {0} def test_single_message_event() -> None: @@ -87,7 +87,7 @@ def test_single_message_event() -> None: # Should have boundaries before and after the single message assert 0 in view.manipulation_indices assert 1 in view.manipulation_indices - assert view.manipulation_indices == [0, 1] + assert view.manipulation_indices == {0, 1} def test_multiple_message_events() -> None: @@ -100,7 +100,7 @@ def test_multiple_message_events() -> None: view = View.from_events(events) # Each message is its own atomic unit, so boundaries exist between all of them - assert view.manipulation_indices == [0, 1, 2, 3] + assert view.manipulation_indices == {0, 1, 2, 3} def test_single_action_observation_pair() -> None: @@ -112,7 +112,7 @@ def test_single_action_observation_pair() -> None: indices = View.from_events(events).manipulation_indices # The pair is an atomic unit, so boundaries are only at start and end - assert indices == [0, 2] + assert indices == {0, 2} def test_action_observation_with_message_events() -> None: @@ -126,7 +126,7 @@ def test_action_observation_with_message_events() -> None: indices = View.from_events(events).manipulation_indices # Boundaries: [0 msg1 1 (action+obs) 3 msg2 4] - assert indices == [0, 1, 3, 4] + assert indices == {0, 1, 3, 4} def test_batch_of_actions_simple() -> None: @@ -148,7 +148,7 @@ def test_batch_of_actions_simple() -> None: # All actions are part of the same batch, and observations extend the range # The entire batch (actions + observations) is one atomic unit - assert indices == [0, 6] + assert indices == {0, 6} def test_batch_with_interleaved_observations() -> None: @@ -164,7 +164,7 @@ def test_batch_with_interleaved_observations() -> None: indices = View.from_events(events).manipulation_indices # Still one atomic unit because actions share llm_response_id - assert indices == [0, 4] + assert indices == {0, 4} def test_multiple_separate_batches() -> None: @@ -194,7 +194,7 @@ def test_multiple_separate_batches() -> None: indices = View.from_events(events).manipulation_indices # Two atomic units: batch1 (indices 0-3) and batch2 (indices 4-7) - assert indices == [0, 4, 8] + assert indices == {0, 4, 8} def test_batches_separated_by_messages() -> None: @@ -217,7 +217,7 @@ def test_batches_separated_by_messages() -> None: indices = View.from_events(events).manipulation_indices # [0 msg1 1 (batch1: action1,action2,obs1,obs2) 5 msg2 6 (batch2) 8 msg3 9] - assert indices == [0, 1, 5, 6, 8, 9] + assert indices == {0, 1, 5, 6, 8, 9} def test_single_action_in_batch() -> None: @@ -229,7 +229,7 @@ def test_single_action_in_batch() -> None: indices = View.from_events(events).manipulation_indices # Single-action batch is still an atomic unit - assert indices == [0, 2] + assert indices == {0, 2} def test_complex_interleaved_scenario() -> None: @@ -294,7 +294,7 @@ def test_complex_interleaved_scenario() -> None: # - 7: after msg3, before batch2 # - 9: after batch2 - assert indices == [0, 1, 6, 7, 9] + assert indices == {0, 1, 6, 7, 9} def test_observations_extend_batch_range() -> None: @@ -313,7 +313,7 @@ def test_observations_extend_batch_range() -> None: # Batch includes actions 0-1 and observations 3-4 # Message at 2 falls within the batch range, so treated as part of it # Range: min=0, max=4 - assert indices == [0, 5] + assert indices == {0, 5} def test_batch_with_all_observations() -> None: @@ -332,7 +332,7 @@ def test_batch_with_all_observations() -> None: indices = view.manipulation_indices # The batch is one atomic unit containing both action-observation pairs - assert indices == [0, 4] + assert indices == {0, 4} def test_interleaved_batches_and_messages() -> None: @@ -353,7 +353,7 @@ def test_interleaved_batches_and_messages() -> None: indices = View.from_events(events).manipulation_indices # [0 msg1 1 batch1 3 msg2 4 batch2 6 msg3 7] - assert indices == [0, 1, 3, 4, 6, 7] + assert indices == {0, 1, 3, 4, 6, 7} def test_three_action_batch() -> None: @@ -370,7 +370,7 @@ def test_three_action_batch() -> None: indices = View.from_events(events).manipulation_indices # All part of one batch - assert indices == [0, 6] + assert indices == {0, 6} def test_consecutive_atomic_units() -> None: @@ -387,7 +387,7 @@ def test_consecutive_atomic_units() -> None: indices = View.from_events(events).manipulation_indices # [0 msg1 1 msg2 2 batch 4 msg3 5] - assert indices == [0, 1, 2, 4, 5] + assert indices == {0, 1, 2, 4, 5} # Verify atomic units: # events[0:1] = [msg1] @@ -411,7 +411,7 @@ def test_forgetting_range_selection() -> None: indices = View.from_events(events).manipulation_indices # [0 msg1 1 batch 5 msg2 6] - assert indices == [0, 1, 5, 6] + assert indices == {0, 1, 5, 6} # To forget the batch: forget events[1:5] # That would remove action1, action2, obs1, obs2 as an atomic unit