diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py index 9b9058ec40..24aeb8e589 100644 --- a/openhands-sdk/openhands/sdk/agent/agent.py +++ b/openhands-sdk/openhands/sdk/agent/agent.py @@ -146,7 +146,7 @@ def step( # Prepare LLM messages using the utility function _messages_or_condensation = prepare_llm_messages( - state.events, condenser=self.condenser + state.events, condenser=self.condenser, llm=self.llm ) # Process condensation event before agent sampels another action diff --git a/openhands-sdk/openhands/sdk/agent/utils.py b/openhands-sdk/openhands/sdk/agent/utils.py index 65943658cd..273160d407 100644 --- a/openhands-sdk/openhands/sdk/agent/utils.py +++ b/openhands-sdk/openhands/sdk/agent/utils.py @@ -117,6 +117,7 @@ def prepare_llm_messages( events: Sequence[Event], condenser: None = None, additional_messages: list[Message] | None = None, + llm: LLM | None = None, ) -> list[Message]: ... @@ -125,6 +126,7 @@ def prepare_llm_messages( events: Sequence[Event], condenser: CondenserBase, additional_messages: list[Message] | None = None, + llm: LLM | None = None, ) -> list[Message] | Condensation: ... @@ -132,6 +134,7 @@ def prepare_llm_messages( events: Sequence[Event], condenser: CondenserBase | None = None, additional_messages: list[Message] | None = None, + llm: LLM | None = None, ) -> list[Message] | Condensation: """Prepare LLM messages from conversation context. @@ -140,13 +143,15 @@ def prepare_llm_messages( It handles condensation internally and calls the callback when needed. 
Args: - state: The conversation state containing events + events: Sequence of events to prepare messages from condenser: Optional condenser for handling context window limits additional_messages: Optional additional messages to append - on_event: Optional callback for handling condensation events + llm: Optional LLM instance from the agent, passed to condenser for + token counting or other LLM features Returns: - List of messages ready for LLM completion + List of messages ready for LLM completion, or a Condensation event + if condensation is needed Raises: RuntimeError: If condensation is needed but no callback is provided @@ -160,7 +165,7 @@ def prepare_llm_messages( # produce a list of events, exactly as expected, or a # new condensation that needs to be processed if condenser is not None: - condensation_result = condenser.condense(view) + condensation_result = condenser.condense(view, agent_llm=llm) match condensation_result: case View(): diff --git a/openhands-sdk/openhands/sdk/context/condenser/base.py b/openhands-sdk/openhands/sdk/context/condenser/base.py index 74082926ba..938b3495d9 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/base.py +++ b/openhands-sdk/openhands/sdk/context/condenser/base.py @@ -3,6 +3,7 @@ from openhands.sdk.context.view import View from openhands.sdk.event.condenser import Condensation +from openhands.sdk.llm import LLM from openhands.sdk.utils.models import ( DiscriminatedUnionMixin, ) @@ -28,7 +29,7 @@ class CondenserBase(DiscriminatedUnionMixin, ABC): """ @abstractmethod - def condense(self, view: View) -> View | Condensation: + def condense(self, view: View, agent_llm: LLM | None = None) -> View | Condensation: """Condense a sequence of events into a potentially smaller list. New condenser strategies should override this method to implement their own @@ -37,6 +38,8 @@ def condense(self, view: View) -> View | Condensation: Args: view: A view of the history containing all events that should be condensed. 
+ agent_llm: LLM instance used by the agent. Condensers use this for token + counting purposes. Defaults to None. Returns: View | Condensation: A condensed view of the events or an event indicating @@ -77,18 +80,20 @@ class RollingCondenser(PipelinableCondenserBase, ABC): """ @abstractmethod - def should_condense(self, view: View) -> bool: + def should_condense(self, view: View, agent_llm: LLM | None = None) -> bool: """Determine if a view should be condensed.""" @abstractmethod - def get_condensation(self, view: View) -> Condensation: + def get_condensation( + self, view: View, agent_llm: LLM | None = None + ) -> Condensation: """Get the condensation from a view.""" - def condense(self, view: View) -> View | Condensation: + def condense(self, view: View, agent_llm: LLM | None = None) -> View | Condensation: # If we trigger the condenser-specific condensation threshold, compute and # return the condensation. - if self.should_condense(view): - return self.get_condensation(view) + if self.should_condense(view, agent_llm=agent_llm): + return self.get_condensation(view, agent_llm=agent_llm) # Otherwise we're safe to just return the view. 
else: diff --git a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py index 610879577e..39fbe3d9a5 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py +++ b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py @@ -1,19 +1,43 @@ import os +from collections.abc import Sequence +from enum import Enum from pydantic import Field, model_validator from openhands.sdk.context.condenser.base import RollingCondenser +from openhands.sdk.context.condenser.utils import ( + get_suffix_length_for_token_reduction, + get_total_token_count, +) from openhands.sdk.context.prompts import render_template from openhands.sdk.context.view import View +from openhands.sdk.event.base import LLMConvertibleEvent from openhands.sdk.event.condenser import Condensation from openhands.sdk.event.llm_convertible import MessageEvent from openhands.sdk.llm import LLM, Message, TextContent from openhands.sdk.observability.laminar import observe +class Reason(Enum): + """Reasons for condensation.""" + + REQUEST = "request" + TOKENS = "tokens" + EVENTS = "events" + + class LLMSummarizingCondenser(RollingCondenser): + """LLM-based condenser that summarizes forgotten events. + + Uses an independent LLM (stored in the `llm` attribute) for generating summaries + of forgotten events. The optional `agent_llm` parameter passed to condense() is + the LLM used by the agent for token counting purposes, and you should not assume + it is the same as the one defined in this condenser. 
+ """ + llm: LLM max_size: int = Field(default=120, gt=0) + max_tokens: int | None = None keep_first: int = Field(default=4, ge=0) @model_validator(mode="after") @@ -29,23 +53,47 @@ def validate_keep_first_vs_max_size(self): def handles_condensation_requests(self) -> bool: return True - def should_condense(self, view: View) -> bool: - if view.unhandled_condensation_request: - return True - return len(view) > self.max_size + def get_condensation_reasons( + self, view: View, agent_llm: LLM | None = None + ) -> set[Reason]: + """Determine the reasons why the view should be condensed. + + Args: + view: The current view to evaluate. + agent_llm: The LLM used by the agent. Required if token counting is needed. + + Returns: + A set of Reason enums indicating why condensation is needed. + """ + reasons = set() - @observe(ignore_inputs=["view"]) - def get_condensation(self, view: View) -> Condensation: - head = view[: self.keep_first] - target_size = self.max_size // 2 + # Reason 1: Unhandled condensation request. The view handles the detection of + # these requests while processing the event stream. if view.unhandled_condensation_request: - # Condensation triggered by a condensation request - # should be calculated based on the view size. - target_size = len(view) // 2 - # Number of events to keep from the tail -- target size, minus however many - # prefix events from the head, minus one for the summarization event - events_from_tail = target_size - len(head) - 1 + reasons.add(Reason.REQUEST) + + # Reason 2: Token limit is provided and exceeded. + if self.max_tokens and agent_llm: + total_tokens = get_total_token_count(view.events, agent_llm) + if total_tokens > self.max_tokens: + reasons.add(Reason.TOKENS) + + # Reason 3: View exceeds maximum size in number of events. 
+ if len(view) > self.max_size: + reasons.add(Reason.EVENTS) + + return reasons + def should_condense(self, view: View, agent_llm: LLM | None = None) -> bool: + reasons = self.get_condensation_reasons(view, agent_llm) + return reasons != set() + + def _get_summary_event_content(self, view: View) -> str: + """Extract the text content from the summary event in the view, if any. + + If there is no summary event or it does not contain text content, returns an + empty string. + """ summary_event_content: str = "" summary_event = view.summary_event @@ -54,9 +102,23 @@ def get_condensation(self, view: View) -> Condensation: if isinstance(message_content, TextContent): summary_event_content = message_content.text - # Identify events to be forgotten (those not in head or tail) - forgotten_events = view[self.keep_first : -events_from_tail] + return summary_event_content + def _generate_condensation( + self, + summary_event_content: str, + forgotten_events: Sequence[LLMConvertibleEvent], + ) -> Condensation: + """Generate a condensation by using the condenser's LLM to summarize forgotten + events. + + Args: + summary_event_content: The content of the previous summary event. + forgotten_events: The list of events to be summarized. + + Returns: + Condensation: The generated condensation object. + """ # Convert events to strings for the template event_strings = [str(forgotten_event) for forgotten_event in forgotten_events] @@ -87,3 +149,72 @@ def get_condensation(self, view: View) -> Condensation: summary_offset=self.keep_first, llm_response_id=llm_response.id, ) + + def _get_forgotten_events( + self, view: View, agent_llm: LLM | None = None + ) -> Sequence[LLMConvertibleEvent]: + """Identify events to be forgotten. + + Relies on the condensation reasons to determine how many events we need to drop + in order to maintain our resource constraints. + + Args: + view: The current view from which to identify forgotten events. 
+ agent_llm: The LLM used by the agent, required for token-based calculations. + + Returns: + A sequence of events to be forgotten. + """ + reasons = self.get_condensation_reasons(view, agent_llm=agent_llm) + assert reasons != set(), "No condensation reasons found." + + suffix_events_to_keep: set[int] = set() + + if Reason.REQUEST in reasons: + target_size = len(view) // 2 + suffix_events_to_keep.add(target_size - self.keep_first - 1) + + if Reason.EVENTS in reasons: + target_size = self.max_size // 2 + suffix_events_to_keep.add(target_size - self.keep_first - 1) + + if Reason.TOKENS in reasons: + # Compute the number of tokens we need to eliminate to be under half the + # max_tokens value. We know max_tokens and the agent LLM are not None here + # because we can't have Reason.TOKENS without them. + assert self.max_tokens is not None + assert agent_llm is not None + + total_tokens = get_total_token_count(view.events, agent_llm) + tokens_to_reduce = total_tokens - (self.max_tokens // 2) + + suffix_length = get_suffix_length_for_token_reduction( + events=view.events[self.keep_first :], + llm=agent_llm, + token_reduction=tokens_to_reduce, + ) + + suffix_events_to_keep.add(suffix_length) + + # We might have multiple reasons to condense, so pick the strictest condensation + # to ensure all resource constraints are met. + events_from_tail = min(suffix_events_to_keep) + + # Identify events to be forgotten (those not in head or tail) + if events_from_tail == 0: + return view[self.keep_first :] + return view[self.keep_first : -events_from_tail] + + @observe(ignore_inputs=["view", "agent_llm"]) + def get_condensation( + self, view: View, agent_llm: LLM | None = None + ) -> Condensation: + # The condensation is dependent on the events we want to drop and the previous + # summary. 
+ summary_event_content = self._get_summary_event_content(view) + forgotten_events = self._get_forgotten_events(view, agent_llm=agent_llm) + + return self._generate_condensation( + summary_event_content=summary_event_content, + forgotten_events=forgotten_events, + ) diff --git a/openhands-sdk/openhands/sdk/context/condenser/no_op_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/no_op_condenser.py index b4a3053fe4..12cf6be7e4 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/no_op_condenser.py +++ b/openhands-sdk/openhands/sdk/context/condenser/no_op_condenser.py @@ -1,6 +1,7 @@ from openhands.sdk.context.condenser.base import CondenserBase from openhands.sdk.context.view import View from openhands.sdk.event.condenser import Condensation +from openhands.sdk.llm import LLM class NoOpCondenser(CondenserBase): @@ -9,5 +10,5 @@ class NoOpCondenser(CondenserBase): Primarily intended for testing purposes. """ - def condense(self, view: View) -> View | Condensation: + def condense(self, view: View, agent_llm: LLM | None = None) -> View | Condensation: # noqa: ARG002 return view diff --git a/openhands-sdk/openhands/sdk/context/condenser/pipeline_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/pipeline_condenser.py index c02d2e5bc4..454ba11b12 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/pipeline_condenser.py +++ b/openhands-sdk/openhands/sdk/context/condenser/pipeline_condenser.py @@ -1,15 +1,16 @@ from openhands.sdk.context.condenser.base import CondenserBase from openhands.sdk.context.view import View from openhands.sdk.event.condenser import Condensation +from openhands.sdk.llm import LLM class PipelineCondenser(CondenserBase): """A condenser that applies a sequence of condensers in order. All condensers are defined primarily by their `condense` method, which takes a - `View` and returns either a new `View` or a `Condensation` event. 
That means we can - chain multiple condensers together by passing `View`s along and exiting early if any - condenser returns a `Condensation`. + `View` and an optional `agent_llm` parameter, returning either a new `View` or a + `Condensation` event. That means we can chain multiple condensers together by + passing `View`s along and exiting early if any condenser returns a `Condensation`. For example: @@ -20,20 +21,20 @@ class PipelineCondenser(CondenserBase): CondenserC(...), ]) - result = condenser.condense(view) + result = condenser.condense(view, agent_llm=agent_llm) # Doing the same thing without the pipeline condenser requires more boilerplate # for the monadic chaining other_result = view if isinstance(other_result, View): - other_result = CondenserA(...).condense(other_result) + other_result = CondenserA(...).condense(other_result, agent_llm=agent_llm) if isinstance(other_result, View): - other_result = CondenserB(...).condense(other_result) + other_result = CondenserB(...).condense(other_result, agent_llm=agent_llm) if isinstance(other_result, View): - other_result = CondenserC(...).condense(other_result) + other_result = CondenserC(...).condense(other_result, agent_llm=agent_llm) assert result == other_result """ @@ -41,12 +42,12 @@ class PipelineCondenser(CondenserBase): condensers: list[CondenserBase] """The list of condensers to apply in order.""" - def condense(self, view: View) -> View | Condensation: + def condense(self, view: View, agent_llm: LLM | None = None) -> View | Condensation: result: View | Condensation = view for condenser in self.condensers: if isinstance(result, Condensation): break - result = condenser.condense(result) + result = condenser.condense(result, agent_llm=agent_llm) return result def handles_condensation_requests(self) -> bool: diff --git a/openhands-sdk/openhands/sdk/context/condenser/utils.py b/openhands-sdk/openhands/sdk/context/condenser/utils.py new file mode 100644 index 0000000000..1064bf3b23 --- /dev/null +++ 
b/openhands-sdk/openhands/sdk/context/condenser/utils.py @@ -0,0 +1,149 @@ +from collections.abc import Sequence + +from openhands.sdk.event.base import LLMConvertibleEvent +from openhands.sdk.llm import LLM + + +def get_total_token_count( + events: Sequence[LLMConvertibleEvent], + llm: LLM, +) -> int: + """Calculate the total token count for a list of LLM convertible events. + + This function converts the events to LLM messages and uses the provided LLM + to count the total number of tokens. This is useful for understanding how many + tokens a sequence of events will consume in the context window. + + Args: + events: List of LLM convertible events to count tokens for + llm: The LLM instance to use for token counting (uses the litellm's token + counting utilities) + + Returns: + Total token count for all events converted to messages + + Example: + >>> from openhands.sdk.llm import LLM + >>> from openhands.sdk.event.llm_convertible import MessageEvent + >>> + >>> llm = LLM(model="gpt-4") + >>> events = [ + ... MessageEvent.from_text("Hello, how are you?", source="user"), + ... MessageEvent.from_text("I'm doing great!", source="agent"), + ... ] + >>> token_count = get_total_token_count(events, llm) + >>> print(f"Total tokens: {token_count}") + """ + messages = LLMConvertibleEvent.events_to_messages(list(events)) + return llm.get_token_count(messages) + + +def get_shortest_prefix_above_token_count( + events: Sequence[LLMConvertibleEvent], + llm: LLM, + token_count: int, +) -> int: + """Find the length of the shortest prefix whose token count exceeds the target. + + This function performs a binary search to efficiently find the shortest prefix + of events that, when converted to messages, has a total token count greater than + the specified target token count. 
+ + Args: + events: List of LLM convertible events to search through + llm: The LLM instance to use for token counting (uses the model's tokenizer) + token_count: The target token count threshold + + Returns: + The length of the shortest prefix that exceeds the token count. + Returns 0 if no events are provided. + Returns len(events) if all events combined don't exceed the token count. + + Example: + >>> from openhands.sdk.llm import LLM + >>> from openhands.sdk.event.llm_convertible import MessageEvent + >>> + >>> llm = LLM(model="gpt-4") + >>> events = [ + ... MessageEvent.from_text("Hi", source="user"), + ... MessageEvent.from_text("Hello", source="agent"), + ... MessageEvent.from_text("How are you?", source="user"), + ... MessageEvent.from_text("Great!", source="agent"), + ... ] + >>> prefix_len = get_shortest_prefix_above_token_count(events, llm, 20) + >>> # prefix_len might be 2 if first 2 events exceed 20 tokens + """ + if not events: + return 0 + + # Check if all events combined don't exceed the token count + total_tokens = get_total_token_count(events, llm) + if total_tokens <= token_count: + return len(events) + + # Binary search for the shortest prefix + left, right = 1, len(events) + + while left < right: + mid = (left + right) // 2 + prefix_tokens = get_total_token_count(events[:mid], llm) + + if prefix_tokens > token_count: + # This prefix exceeds the count, try to find a shorter one + right = mid + else: + # This prefix doesn't exceed, we need a longer one + left = mid + 1 + + return left + + +def get_suffix_length_for_token_reduction( + events: Sequence[LLMConvertibleEvent], + llm: LLM, + token_reduction: int, +) -> int: + """Find how many suffix events can be kept while reducing tokens by target amount. + + This function determines the maximum number of events from the end of the list + that can be retained while ensuring the total token count is reduced by at least + the specified amount. 
It uses the get_shortest_prefix_above_token_count function + to find the prefix that must be removed. + + Args: + events: List of LLM convertible events + llm: The LLM instance to use for token counting (uses the model's tokenizer) + token_reduction: The minimum number of tokens to reduce by + + Returns: + The number of events from the end that can be kept (suffix length). + + Example: + >>> from openhands.sdk.llm import LLM + >>> from openhands.sdk.event.llm_convertible import MessageEvent + >>> + >>> llm = LLM(model="gpt-4") + >>> events = [ + ... MessageEvent.from_text("Event 1", source="user"), + ... MessageEvent.from_text("Event 2", source="agent"), + ... MessageEvent.from_text("Event 3", source="user"), + ... MessageEvent.from_text("Event 4", source="agent"), + ... ] + >>> # Suppose total is 100 tokens, and we want to reduce by 40 tokens + >>> suffix_len = get_suffix_length_for_token_reduction(events, llm, 40) + >>> # suffix_len tells us how many events from the end we can keep + >>> # If first 2 events = 45 tokens, suffix_len = 2 (keep last 2 events) + """ + if not events: + return 0 + + if token_reduction <= 0: + return len(events) + + # Find the shortest prefix that exceeds the token reduction target + prefix_length = get_shortest_prefix_above_token_count(events, llm, token_reduction) + + # The suffix length is what remains after removing the prefix + suffix_length = len(events) - prefix_length + + return suffix_length diff --git a/tests/github_workflows/test_resolve_model_config.py b/tests/github_workflows/test_resolve_model_config.py index 472047a8d9..66be72316d 100644 --- a/tests/github_workflows/test_resolve_model_config.py +++ b/tests/github_workflows/test_resolve_model_config.py @@ -1,26 +1,28 @@ -"""Tests for resolve_model_configs.py GitHub Actions script.""" +"""Tests for resolve_model_config.py GitHub Actions script.""" import sys from pathlib import Path +from unittest.mock import patch -# Import the functions from resolve_model_configs.py +# 
Import the functions from resolve_model_config.py run_eval_path = Path(__file__).parent.parent.parent / ".github" / "run-eval" sys.path.append(str(run_eval_path)) -from resolve_model_configs import ( # noqa: E402 # type: ignore[import-not-found] +from resolve_model_config import ( # noqa: E402 # type: ignore[import-not-found] find_models_by_id, ) def test_find_models_by_id_single_model(): """Test finding a single model by ID.""" - models = [ - {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, - {"id": "gpt-3.5", "display_name": "GPT-3.5", "llm_config": {}}, - ] + mock_models = { + "gpt-4": {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, + "gpt-3.5": {"id": "gpt-3.5", "display_name": "GPT-3.5", "llm_config": {}}, + } model_ids = ["gpt-4"] - result = find_models_by_id(models, model_ids) + with patch.dict("resolve_model_config.MODELS", mock_models): + result = find_models_by_id(model_ids) assert len(result) == 1 assert result[0]["id"] == "gpt-4" @@ -29,14 +31,15 @@ def test_find_models_by_id_single_model(): def test_find_models_by_id_multiple_models(): """Test finding multiple models by ID.""" - models = [ - {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, - {"id": "gpt-3.5", "display_name": "GPT-3.5", "llm_config": {}}, - {"id": "claude-3", "display_name": "Claude 3", "llm_config": {}}, - ] + mock_models = { + "gpt-4": {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, + "gpt-3.5": {"id": "gpt-3.5", "display_name": "GPT-3.5", "llm_config": {}}, + "claude-3": {"id": "claude-3", "display_name": "Claude 3", "llm_config": {}}, + } model_ids = ["gpt-4", "claude-3"] - result = find_models_by_id(models, model_ids) + with patch.dict("resolve_model_config.MODELS", mock_models): + result = find_models_by_id(model_ids) assert len(result) == 2 assert result[0]["id"] == "gpt-4" @@ -45,14 +48,15 @@ def test_find_models_by_id_multiple_models(): def test_find_models_by_id_preserves_order(): """Test that model order matches the requested IDs 
order.""" - models = [ - {"id": "a", "display_name": "A", "llm_config": {}}, - {"id": "b", "display_name": "B", "llm_config": {}}, - {"id": "c", "display_name": "C", "llm_config": {}}, - ] + mock_models = { + "a": {"id": "a", "display_name": "A", "llm_config": {}}, + "b": {"id": "b", "display_name": "B", "llm_config": {}}, + "c": {"id": "c", "display_name": "C", "llm_config": {}}, + } model_ids = ["c", "a", "b"] - result = find_models_by_id(models, model_ids) + with patch.dict("resolve_model_config.MODELS", mock_models): + result = find_models_by_id(model_ids) assert len(result) == 3 assert [m["id"] for m in result] == ["c", "a", "b"] @@ -62,33 +66,35 @@ def test_find_models_by_id_missing_model_exits(): """Test that missing model ID causes exit.""" import pytest - models = [ - {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, - ] + mock_models = { + "gpt-4": {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, + } model_ids = ["gpt-4", "nonexistent"] - with pytest.raises(SystemExit) as exc_info: - find_models_by_id(models, model_ids) + with patch.dict("resolve_model_config.MODELS", mock_models): + with pytest.raises(SystemExit) as exc_info: + find_models_by_id(model_ids) assert exc_info.value.code == 1 def test_find_models_by_id_empty_list(): """Test finding models with empty list.""" - models = [ - {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, - ] + mock_models = { + "gpt-4": {"id": "gpt-4", "display_name": "GPT-4", "llm_config": {}}, + } model_ids = [] - result = find_models_by_id(models, model_ids) + with patch.dict("resolve_model_config.MODELS", mock_models): + result = find_models_by_id(model_ids) assert result == [] def test_find_models_by_id_preserves_full_config(): """Test that full model configuration is preserved.""" - models = [ - { + mock_models = { + "custom-model": { "id": "custom-model", "display_name": "Custom Model", "llm_config": { @@ -98,10 +104,11 @@ def test_find_models_by_id_preserves_full_config(): }, 
"extra_field": "should be preserved", } - ] + } model_ids = ["custom-model"] - result = find_models_by_id(models, model_ids) + with patch.dict("resolve_model_config.MODELS", mock_models): + result = find_models_by_id(model_ids) assert len(result) == 1 assert result[0]["id"] == "custom-model" diff --git a/tests/integration/README.md b/tests/integration/README.md index 303fa762b9..b46722aaa5 100644 --- a/tests/integration/README.md +++ b/tests/integration/README.md @@ -78,6 +78,7 @@ These tests must pass for releases and verify that the agent can successfully co - **t06_github_pr_browsing** - Tests GitHub PR browsing - **t07_interactive_commands** - Tests interactive command handling - **t08_image_file_viewing** - Tests image file viewing capabilities +- **t09_token_condenser** - Tests that token-based condensation works correctly by verifying `get_token_count()` triggers condensation when token limits are exceeded ### Behavior Tests (`b*.py`) - **Optional** @@ -86,3 +87,22 @@ These tests track quality improvements and don't block releases. They verify tha - **b01_no_premature_implementation** - Tests that the agent doesn't start implementing when asked for advice. Uses a real codebase (software-agent-sdk checked out to a historical commit) to test that the agent explores, provides suggestions, and asks clarifying questions instead of immediately creating or editing files. For more details on behavior testing and guidelines for adding new tests, see [BEHAVIOR_TESTS.md](BEHAVIOR_TESTS.md). + +## Writing Integration Tests + +All integration tests inherit from `BaseIntegrationTest` in `base.py`. The base class provides a consistent framework with several customizable properties: + +### Required Methods + +- **`tools`** (property) - List of tools available to the agent +- **`setup()`** - Initialize test-specific setup (create files, etc.) 
+- **`verify_result()`** - Verify the test succeeded and return `TestResult` + +### Optional Properties + +- **`condenser`** (property) - Optional condenser configuration for the agent (default: `None`) + - Override to test condensation or manage long conversations + - Example: `t09_token_condenser` uses this to verify token counting +- **`max_iteration_per_run`** (property) - Maximum iterations per conversation (default: `100`) + - Override to limit LLM calls for faster tests + - Useful for tests that should complete quickly diff --git a/tests/integration/base.py b/tests/integration/base.py index 3b446a4d17..c7aab944c7 100644 --- a/tests/integration/base.py +++ b/tests/integration/base.py @@ -17,6 +17,7 @@ Message, TextContent, ) +from openhands.sdk.context.condenser import CondenserBase from openhands.sdk.conversation.impl.local_conversation import LocalConversation from openhands.sdk.conversation.visualizer import DefaultConversationVisualizer from openhands.sdk.event.base import Event @@ -89,7 +90,9 @@ def __init__( } self.llm: LLM = LLM(**llm_kwargs, usage_id="test-llm") - self.agent: Agent = Agent(llm=self.llm, tools=self.tools) + self.agent: Agent = Agent( + llm=self.llm, tools=self.tools, condenser=self.condenser + ) self.collected_events: list[Event] = [] self.llm_messages: list[dict[str, Any]] = [] @@ -103,7 +106,7 @@ def __init__( workspace=self.workspace, callbacks=[self.conversation_callback], visualizer=DefaultConversationVisualizer(), # Use default visualizer - max_iteration_per_run=100, + max_iteration_per_run=self.max_iteration_per_run, ) def conversation_callback(self, event: Event): @@ -177,6 +180,24 @@ def tools(self) -> list[Tool]: """List of tools available to the agent.""" pass + @property + def condenser(self) -> CondenserBase | None: + """Optional condenser for the agent. Override to provide a custom condenser. 
+ + Returns: + CondenserBase instance or None (default) + """ + return None + + @property + def max_iteration_per_run(self) -> int: + """Maximum iterations per conversation run. Override to set a custom limit. + + Returns: + Maximum iterations (default: 100) + """ + return 100 + @abstractmethod def setup(self) -> None: """ diff --git a/tests/integration/tests/t09_token_condenser.py b/tests/integration/tests/t09_token_condenser.py new file mode 100644 index 0000000000..ee376e15cf --- /dev/null +++ b/tests/integration/tests/t09_token_condenser.py @@ -0,0 +1,91 @@ +"""Test that agent with token-based condenser successfully triggers condensation. + +This integration test verifies that: +1. An agent can be configured with an LLMSummarizingCondenser using max_tokens +2. The condenser correctly uses get_token_count to measure conversation size +3. Condensation is triggered when token limit is exceeded +""" + +from openhands.sdk import get_logger +from openhands.sdk.context.condenser import LLMSummarizingCondenser +from openhands.sdk.event.condenser import Condensation +from openhands.sdk.tool import Tool, register_tool +from openhands.tools.terminal import TerminalTool +from tests.integration.base import BaseIntegrationTest, TestResult + + +# Instruction designed to generate multiple agent messages +INSTRUCTION = """ +Count from 1 to 1000. For each number, use the echo command to print it along with +a short, unique property of that number (e.g., "1 is the first natural number", +"2 is the only even prime number", etc.). Be creative with your descriptions. + +DO NOT write a script to do this. Instead, interactively call the echo command +1000 times, once for each number from 1 to 1000. + +This won't be efficient -- that is okay, we're using the output as a test for our +context management system. 
+""" + +logger = get_logger(__name__) + + +class TokenCondenserTest(BaseIntegrationTest): + """Test that agent with token-based condenser triggers condensation.""" + + INSTRUCTION: str = INSTRUCTION + + def __init__(self, *args, **kwargs): + """Initialize test with tracking variables.""" + self.condensation_triggered = False + super().__init__(*args, **kwargs) + + @property + def tools(self) -> list[Tool]: + """List of tools available to the agent.""" + register_tool("TerminalTool", TerminalTool) + return [ + Tool(name="TerminalTool"), + ] + + @property + def condenser(self) -> LLMSummarizingCondenser: + """Configure a token-based condenser with low limits to trigger condensation.""" + # Create a condenser with a low token limit to trigger condensation + # Using max_tokens instead of max_size to test token counting + condenser_llm = self.llm.model_copy(update={"usage_id": "test-condenser-llm"}) + return LLMSummarizingCondenser( + llm=condenser_llm, + max_size=1000, # Set high so it doesn't trigger on event count + max_tokens=5000, # Low token limit to ensure condensation triggers + keep_first=2, + ) + + @property + def max_iteration_per_run(self) -> int: + return 50 + + def conversation_callback(self, event): + """Override callback to detect condensation events.""" + super().conversation_callback(event) + + if isinstance(event, Condensation): + self.condensation_triggered = True + logger.info("Condensation detected! Stopping test early.") + self.conversation.pause() + + def setup(self) -> None: + logger.info(f"Token condenser test: max_tokens={self.condenser.max_tokens}") + + def verify_result(self) -> TestResult: + """Verify that condensation was triggered based on token count.""" + if not self.condensation_triggered: + return TestResult( + success=False, + reason="Condensation not triggered. Token counting may not work.", + ) + + return TestResult( + success=True, + reason="Condensation triggered. 
Token counting works correctly.", + ) diff --git a/tests/sdk/agent/test_agent_context_window_condensation.py b/tests/sdk/agent/test_agent_context_window_condensation.py index 61d8e63768..cd52da5429 100644 --- a/tests/sdk/agent/test_agent_context_window_condensation.py +++ b/tests/sdk/agent/test_agent_context_window_condensation.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import pytest from pydantic import PrivateAttr @@ -10,6 +12,10 @@ from openhands.sdk.llm.exceptions import LLMContextWindowExceedError +if TYPE_CHECKING: + from openhands.sdk.event.condenser import Condensation + + class RaisingLLM(LLM): _force_responses: bool = PrivateAttr(default=False) @@ -28,7 +34,9 @@ def responses(self, *, messages, tools=None, **kwargs): # type: ignore[override class HandlesRequestsCondenser(CondenserBase): - def condense(self, view: View): # pragma: no cover - trivial passthrough + def condense( + self, view: View, agent_llm: "LLM | None" = None + ) -> "View | Condensation": # pragma: no cover - trivial passthrough return view def handles_condensation_requests(self) -> bool: @@ -90,7 +98,9 @@ def test_agent_logs_warning_when_no_condenser_on_ctx_exceeded( class NoHandlesRequestsCondenser(CondenserBase): """A condenser that doesn't handle condensation requests.""" - def condense(self, view: View): # pragma: no cover - trivial passthrough + def condense( + self, view: View, agent_llm: "LLM | None" = None + ) -> "View | Condensation": # pragma: no cover - trivial passthrough return view def handles_condensation_requests(self) -> bool: diff --git a/tests/sdk/agent/test_agent_utils.py b/tests/sdk/agent/test_agent_utils.py index 4e94344776..bec339bdb5 100644 --- a/tests/sdk/agent/test_agent_utils.py +++ b/tests/sdk/agent/test_agent_utils.py @@ -208,7 +208,7 @@ def test_prepare_llm_messages_with_condenser_returns_view( # Verify results assert result == condensed_messages mock_from_events.assert_called_once_with(sample_events) - 
mock_condenser.condense.assert_called_once_with(mock_view) + mock_condenser.condense.assert_called_once_with(mock_view, agent_llm=None) mock_events_to_messages.assert_called_once_with(condensed_events) @@ -234,7 +234,7 @@ def test_prepare_llm_messages_with_condenser_returns_condensation( # Verify results assert result == condensation mock_from_events.assert_called_once_with(sample_events) - mock_condenser.condense.assert_called_once_with(mock_view) + mock_condenser.condense.assert_called_once_with(mock_view, agent_llm=None) @patch("openhands.sdk.agent.utils.View.from_events") diff --git a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py index c93d6d7fd2..f4f63d5cdb 100644 --- a/tests/sdk/context/condenser/test_llm_summarizing_condenser.py +++ b/tests/sdk/context/condenser/test_llm_summarizing_condenser.py @@ -6,10 +6,11 @@ from openhands.sdk.context.condenser.llm_summarizing_condenser import ( LLMSummarizingCondenser, + Reason, ) from openhands.sdk.context.view import View from openhands.sdk.event.base import Event -from openhands.sdk.event.condenser import Condensation +from openhands.sdk.event.condenser import Condensation, CondensationRequest from openhands.sdk.event.llm_convertible import MessageEvent from openhands.sdk.llm import ( LLM, @@ -153,7 +154,11 @@ def test_get_condensation_with_previous_summary(mock_llm: LLM) -> None: cast(Any, mock_llm).set_mock_response_content("Updated summary") # Create events with a condensation in the history - events = [message_event(f"Event {i}") for i in range(max_size + 1)] + # Need enough events so that after condensation, the view still exceeds max_size + # Condensation will remove 2 events (events[3] and events[4]) plus itself + # So we need at least max_size + 1 + 3 = 14 events to exceed max_size after + # condensation + events = [message_event(f"Event {i}") for i in range(14)] # Add a condensation to simulate previous summarization 
condensation = Condensation(
@@ -217,5 +222,298 @@ def test_get_condensation_does_not_pass_extra_body(mock_llm: LLM) -> None:
     # Ensure completion was called without an explicit extra_body kwarg
     completion_mock = cast(MagicMock, mock_llm.completion)
     assert completion_mock.call_count == 1
     _, kwargs = completion_mock.call_args
     assert "extra_body" not in kwargs
+
+
+def test_condense_with_agent_llm(mock_llm: LLM) -> None:
+    """Test that condenser accepts and works with optional agent llm parameter."""
+    condenser = LLMSummarizingCondenser(llm=mock_llm, max_size=10, keep_first=2)
+
+    # Create a separate mock for the agent's LLM
+    agent_llm = MagicMock(spec=LLM)
+    agent_llm.model = "gpt-4"
+
+    # Prepare a view that triggers condensation
+    events: list[Event] = [message_event(f"Event {i}") for i in range(12)]
+    view = View.from_events(events)
+
+    # Call condense with the agent's LLM
+    result = condenser.condense(view, agent_llm=agent_llm)
+    assert isinstance(result, Condensation)
+
+    # Verify the condenser still uses its own LLM for summarization
+    completion_mock = cast(MagicMock, mock_llm.completion)
+    assert completion_mock.call_count == 1
+
+    # Agent LLM should not be called for completion (condenser uses its own LLM)
+    assert not agent_llm.completion.called
+
+
+def test_condense_with_token_limit_exceeded(mock_llm: LLM) -> None:
+    """Test that condenser triggers on TOKENS reason when token limit is exceeded."""
+    max_tokens = 100
+    keep_first = 2
+    condenser = LLMSummarizingCondenser(
+        llm=mock_llm, max_size=1000, max_tokens=max_tokens, keep_first=keep_first
+    )
+
+    # Create a separate mock for the agent's LLM with token counting
+    agent_llm = MagicMock(spec=LLM)
+    agent_llm.model = "gpt-4"
+
+    # Mock get_token_count to return predictable values based on message content length
+    def mock_token_count(messages):
+        # Simple heuristic: count characters in all text content
+        # Each character = 0.25 tokens (roughly 4 chars per token)
+        total_chars = 0
+        for msg in 
messages: + for content in msg.content: + if hasattr(content, "text"): + total_chars += len(content.text) + return total_chars // 4 + + agent_llm.get_token_count.side_effect = mock_token_count + + # Create events that exceed token limit + # Each event has 40 chars = 10 tokens + # 15 events = 150 tokens (exceeds max_tokens of 100) + events: list[Event] = [message_event("A" * 40) for i in range(15)] + view = View.from_events(events) + + # Verify that TOKENS is the condensation reason + reasons = condenser.get_condensation_reasons(view, agent_llm=agent_llm) + assert Reason.TOKENS in reasons + assert Reason.EVENTS not in reasons # Should not trigger on event count + assert Reason.REQUEST not in reasons + + # Condense the view + result = condenser.condense(view, agent_llm=agent_llm) + assert isinstance(result, Condensation) + + # Verify the condenser used its own LLM for summarization + completion_mock = cast(MagicMock, mock_llm.completion) + assert completion_mock.call_count == 1 + + # Verify forgotten events were calculated based on token reduction + assert len(result.forgotten_event_ids) > 0 + + +def test_condense_with_request_and_events_reasons(mock_llm: LLM) -> None: + """Test condensation when both REQUEST and EVENTS reasons are true simultaneously. + + Verifies that the most aggressive condensation (minimum suffix) is chosen. 
+ """ + max_size = 20 + keep_first = 2 + condenser = LLMSummarizingCondenser( + llm=mock_llm, max_size=max_size, keep_first=keep_first + ) + + # Create events that exceed max_size AND include a condensation request + # 25 events > max_size of 20 (triggers EVENTS) + # Plus a CondensationRequest (triggers REQUEST) + events: list[Event] = [message_event(f"Event {i}") for i in range(25)] + events.append(CondensationRequest()) + view = View.from_events(events) + + # Verify both reasons are present + reasons = condenser.get_condensation_reasons(view, agent_llm=None) + assert Reason.REQUEST in reasons + assert Reason.EVENTS in reasons + assert Reason.TOKENS not in reasons + + # Get the condensation + result = condenser.condense(view) + assert isinstance(result, Condensation) + + # Calculate expected behavior: + # REQUEST: target_size = len(view) // 2 = 25 // 2 = 12 + # suffix_to_keep = 12 - keep_first - 1 = 12 - 2 - 1 = 9 + # EVENTS: target_size = max_size // 2 = 20 // 2 = 10 + # suffix_to_keep = 10 - keep_first - 1 = 10 - 2 - 1 = 7 + # Most aggressive: min(9, 7) = 7 + + # Forgotten events should be from index keep_first to -(7) + # Total events in view = 25 (CondensationRequest is not in view.events) + # Forgotten: events[2:18] = 16 events + expected_forgotten_count = 25 - keep_first - 7 + assert len(result.forgotten_event_ids) == expected_forgotten_count + + +def test_condense_with_request_and_tokens_reasons(mock_llm: LLM) -> None: + """Test condensation when both REQUEST and TOKENS reasons are true simultaneously. + + Verifies that the most aggressive condensation (minimum suffix) is chosen. 
+ """ + max_tokens = 100 + keep_first = 2 + condenser = LLMSummarizingCondenser( + llm=mock_llm, max_size=1000, max_tokens=max_tokens, keep_first=keep_first + ) + + # Create a separate mock for the agent's LLM with token counting + agent_llm = MagicMock(spec=LLM) + agent_llm.model = "gpt-4" + + # Mock get_token_count to return predictable values + def mock_token_count(messages): + total_chars = 0 + for msg in messages: + for content in msg.content: + if hasattr(content, "text"): + total_chars += len(content.text) + return total_chars // 4 + + agent_llm.get_token_count.side_effect = mock_token_count + + # Create 20 events with 40 chars each = 10 tokens each = 200 total tokens + # This exceeds max_tokens of 100 (triggers TOKENS) + events: list[Event] = [message_event("A" * 40) for i in range(20)] + # Add a CondensationRequest (triggers REQUEST) + events.append(CondensationRequest()) + view = View.from_events(events) + + # Verify both reasons are present + reasons = condenser.get_condensation_reasons(view, agent_llm=agent_llm) + assert Reason.REQUEST in reasons + assert Reason.TOKENS in reasons + assert Reason.EVENTS not in reasons + + # Get the condensation + result = condenser.condense(view, agent_llm=agent_llm) + assert isinstance(result, Condensation) + + # The most aggressive condensation should be chosen (minimum suffix) + assert len(result.forgotten_event_ids) > 0 + + +def test_condense_with_events_and_tokens_reasons(mock_llm: LLM) -> None: + """Test condensation when both EVENTS and TOKENS reasons are true simultaneously. + + Verifies that the most aggressive condensation (minimum suffix) is chosen. 
+ """ + max_size = 15 + max_tokens = 100 + keep_first = 2 + condenser = LLMSummarizingCondenser( + llm=mock_llm, max_size=max_size, max_tokens=max_tokens, keep_first=keep_first + ) + + # Create a separate mock for the agent's LLM with token counting + agent_llm = MagicMock(spec=LLM) + agent_llm.model = "gpt-4" + + def mock_token_count(messages): + total_chars = 0 + for msg in messages: + for content in msg.content: + if hasattr(content, "text"): + total_chars += len(content.text) + return total_chars // 4 + + agent_llm.get_token_count.side_effect = mock_token_count + + # Create 20 events (exceeds max_size of 15) with 40 chars each + # 20 events * 10 tokens = 200 tokens (exceeds max_tokens of 100) + events: list[Event] = [message_event("A" * 40) for i in range(20)] + view = View.from_events(events) + + # Verify both reasons are present + reasons = condenser.get_condensation_reasons(view, agent_llm=agent_llm) + assert Reason.EVENTS in reasons + assert Reason.TOKENS in reasons + assert Reason.REQUEST not in reasons + + # Get the condensation + result = condenser.condense(view, agent_llm=agent_llm) + assert isinstance(result, Condensation) + + # The most aggressive condensation should be chosen (minimum suffix) + assert len(result.forgotten_event_ids) > 0 + + +def test_condense_with_all_three_reasons(mock_llm: LLM) -> None: + """Test condensation when all three reasons are true simultaneously. + + Verifies that the most aggressive condensation (minimum suffix) is chosen + when REQUEST, EVENTS, and TOKENS all trigger at once. 
+ """ + max_size = 15 + max_tokens = 100 + keep_first = 2 + condenser = LLMSummarizingCondenser( + llm=mock_llm, max_size=max_size, max_tokens=max_tokens, keep_first=keep_first + ) + + # Create a separate mock for the agent's LLM with token counting + agent_llm = MagicMock(spec=LLM) + agent_llm.model = "gpt-4" + + def mock_token_count(messages): + total_chars = 0 + for msg in messages: + for content in msg.content: + if hasattr(content, "text"): + total_chars += len(content.text) + return total_chars // 4 + + agent_llm.get_token_count.side_effect = mock_token_count + + # Create 20 events (exceeds max_size of 15) with 40 chars each + # 20 events * 10 tokens = 200 tokens (exceeds max_tokens of 100) + events: list[Event] = [message_event("A" * 40) for i in range(20)] + # Add CondensationRequest (triggers REQUEST) + events.append(CondensationRequest()) + view = View.from_events(events) + + # Verify all three reasons are present + reasons = condenser.get_condensation_reasons(view, agent_llm=agent_llm) + assert Reason.REQUEST in reasons + assert Reason.EVENTS in reasons + assert Reason.TOKENS in reasons + + # Get the condensation + result = condenser.condense(view, agent_llm=agent_llm) + assert isinstance(result, Condensation) + + # The most aggressive condensation should be chosen (minimum suffix) + # This means the most events should be forgotten + assert len(result.forgotten_event_ids) > 0 + + # Verify the condenser used its own LLM for summarization + completion_mock = cast(MagicMock, mock_llm.completion) + assert completion_mock.call_count == 1 + + +def test_most_aggressive_condensation_chosen(mock_llm: LLM) -> None: + """Test that the minimum suffix is chosen when multiple reasons provide different + targets. + + This test explicitly verifies the min() logic at line 200 of the condenser. 
+ """ + max_size = 30 # Set high so EVENTS triggers with specific target + keep_first = 2 + condenser = LLMSummarizingCondenser( + llm=mock_llm, max_size=max_size, keep_first=keep_first + ) + + # Create a scenario where REQUEST and EVENTS give different suffix sizes + # 40 events total + events: list[Event] = [message_event(f"Event {i}") for i in range(40)] + events.append(CondensationRequest()) + view = View.from_events(events) + + # Calculate expected suffix lengths: + # REQUEST: target_size = len(view) // 2 = 40 // 2 = 20 + # suffix_to_keep = 20 - keep_first - 1 = 20 - 2 - 1 = 17 + # EVENTS: target_size = max_size // 2 = 30 // 2 = 15 + # suffix_to_keep = 15 - keep_first - 1 = 15 - 2 - 1 = 12 + # Most aggressive: min(17, 12) = 12 + + result = condenser.condense(view) + assert isinstance(result, Condensation) + + # Forgotten events: events[keep_first : -12] = events[2:28] = 26 events + expected_forgotten_count = 40 - keep_first - 12 + assert len(result.forgotten_event_ids) == expected_forgotten_count diff --git a/tests/sdk/context/condenser/test_no_op_condenser.py b/tests/sdk/context/condenser/test_no_op_condenser.py index 3b591c79da..c3ae8b5c43 100644 --- a/tests/sdk/context/condenser/test_no_op_condenser.py +++ b/tests/sdk/context/condenser/test_no_op_condenser.py @@ -1,8 +1,10 @@ +from unittest.mock import MagicMock + from openhands.sdk.context.condenser.no_op_condenser import NoOpCondenser from openhands.sdk.context.view import View from openhands.sdk.event.base import Event from openhands.sdk.event.llm_convertible import MessageEvent -from openhands.sdk.llm import Message, TextContent +from openhands.sdk.llm import LLM, Message, TextContent def message_event(content: str) -> MessageEvent: @@ -26,3 +28,23 @@ def test_noop_condenser() -> None: condensation_result = condenser.condense(view) assert isinstance(condensation_result, View) assert condensation_result.events == events + + +def test_noop_condenser_with_llm() -> None: + """Test that NoOpCondenser works 
with optional agent_llm parameter.""" + events: list[Event] = [ + message_event("Event 1"), + message_event("Event 2"), + message_event("Event 3"), + ] + + condenser = NoOpCondenser() + view = View.from_events(events) + + # Create a mock LLM + mock_llm = MagicMock(spec=LLM) + + # Condense with agent_llm parameter + condensation_result = condenser.condense(view, agent_llm=mock_llm) + assert isinstance(condensation_result, View) + assert condensation_result.events == events diff --git a/tests/sdk/context/condenser/test_utils.py b/tests/sdk/context/condenser/test_utils.py new file mode 100644 index 0000000000..5d8a08f54b --- /dev/null +++ b/tests/sdk/context/condenser/test_utils.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import pytest + +from openhands.sdk.context.condenser.utils import ( + get_shortest_prefix_above_token_count, + get_suffix_length_for_token_reduction, + get_total_token_count, +) +from openhands.sdk.event.llm_convertible import MessageEvent +from openhands.sdk.llm import LLM, Message, TextContent + + +def message_event(content: str) -> MessageEvent: + """Helper function to create a MessageEvent for testing.""" + return MessageEvent( + llm_message=Message(role="user", content=[TextContent(text=content)]), + source="user", + ) + + +@pytest.fixture +def mock_llm() -> LLM: + """Create a mock LLM with token counting capability.""" + mock_llm = MagicMock(spec=LLM) + mock_llm.model = "test-model" + + # Mock get_token_count to return predictable values based on message content length + def mock_token_count(messages): + # Simple heuristic: count characters in all text content + # Each character = 0.25 tokens (roughly 4 chars per token) + total_chars = 0 + for msg in messages: + for content in msg.content: + if hasattr(content, "text"): + total_chars += len(content.text) + return total_chars // 4 + + mock_llm.get_token_count.side_effect = mock_token_count + + return mock_llm + + +class TestGetTotalTokenCount: + """Tests for 
get_total_token_count function.""" + + def test_empty_events(self, mock_llm: LLM): + """Test with empty event list.""" + events = [] + token_count = get_total_token_count(events, mock_llm) + assert token_count == 0 + + def test_single_event(self, mock_llm: LLM): + """Test with a single event.""" + events = [message_event("Hello world")] # 11 chars -> 2 tokens + token_count = get_total_token_count(events, mock_llm) + assert token_count == 2 + + def test_multiple_events(self, mock_llm: LLM): + """Test with multiple events.""" + events = [ + message_event("Hello"), # 5 chars -> 1 token + message_event("World"), # 5 chars -> 1 token + message_event("Test message"), # 12 chars -> 3 tokens + ] + token_count = get_total_token_count(events, mock_llm) + assert token_count == 5 # (5 + 5 + 12) // 4 = 5 + + def test_events_converted_to_messages(self, mock_llm: LLM): + """Test that events are properly converted to messages.""" + events = [message_event("Test")] + get_total_token_count(events, mock_llm) + + # Verify get_token_count was called + assert mock_llm.get_token_count.called # type: ignore + # Verify it was called with a list of messages + call_args = mock_llm.get_token_count.call_args[0][0] # type: ignore + assert isinstance(call_args, list) + assert all(isinstance(msg, Message) for msg in call_args) + + +class TestGetShortestPrefixAboveTokenCount: + """Tests for get_shortest_prefix_above_token_count function.""" + + def test_empty_events(self, mock_llm: LLM): + """Test with empty event list.""" + events = [] + prefix_length = get_shortest_prefix_above_token_count(events, mock_llm, 10) + assert prefix_length == 0 + + def test_no_prefix_exceeds_token_count(self, mock_llm: LLM): + """Test when total tokens don't exceed the target.""" + events = [ + message_event("Hi"), # 2 chars -> 0 tokens + message_event("Bye"), # 3 chars -> 0 tokens + ] + prefix_length = get_shortest_prefix_above_token_count(events, mock_llm, 100) + assert prefix_length == len(events) + + def 
test_single_event_exceeds(self, mock_llm: LLM): + """Test when first event alone exceeds the token count.""" + events = [ + message_event("A" * 100), # 100 chars -> 25 tokens + message_event("B" * 100), # 100 chars -> 25 tokens + ] + prefix_length = get_shortest_prefix_above_token_count(events, mock_llm, 20) + assert prefix_length == 1 + + def test_multiple_events_needed(self, mock_llm: LLM): + """Test when multiple events are needed to exceed token count.""" + events = [ + message_event("A" * 20), # 20 chars -> 5 tokens + message_event("B" * 20), # 20 chars -> 5 tokens + message_event("C" * 20), # 20 chars -> 5 tokens + message_event("D" * 20), # 20 chars -> 5 tokens + ] + # Need prefix of 3 events to exceed 10 tokens (15 > 10) + prefix_length = get_shortest_prefix_above_token_count(events, mock_llm, 10) + assert prefix_length == 3 + + def test_exact_boundary(self, mock_llm: LLM): + """Test behavior at exact token count boundary.""" + events = [ + message_event("A" * 40), # 40 chars -> 10 tokens + message_event("B" * 40), # 40 chars -> 10 tokens + ] + # 10 tokens is not > 10, need 2 events for 20 tokens + prefix_length = get_shortest_prefix_above_token_count(events, mock_llm, 10) + assert prefix_length == 2 + + def test_all_events_needed(self, mock_llm: LLM): + """Test when all events together just exceed the token count.""" + events = [ + message_event("A" * 16), # 16 chars -> 4 tokens + message_event("B" * 16), # 16 chars -> 4 tokens + message_event("C" * 16), # 16 chars -> 4 tokens + ] + # Total 12 tokens, need all 3 to exceed 10 + prefix_length = get_shortest_prefix_above_token_count(events, mock_llm, 10) + assert prefix_length == 3 + + +class TestGetSuffixLengthForTokenReduction: + """Tests for get_suffix_length_for_token_reduction function.""" + + def test_empty_events(self, mock_llm: LLM): + """Test with empty event list.""" + events = [] + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, 10) + assert suffix_length == 0 + + def 
test_zero_token_reduction(self, mock_llm: LLM): + """Test with zero token reduction requested.""" + events = [ + message_event("Test"), + message_event("Message"), + ] + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, 0) + assert suffix_length == len(events) + + def test_negative_token_reduction(self, mock_llm: LLM): + """Test with negative token reduction (edge case).""" + events = [ + message_event("Test"), + message_event("Message"), + ] + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, -10) + assert suffix_length == len(events) + + def test_small_reduction(self, mock_llm: LLM): + """Test with small token reduction that removes few events.""" + events = [ + message_event("A" * 40), # 40 chars -> 10 tokens + message_event("B" * 40), # 40 chars -> 10 tokens + message_event("C" * 40), # 40 chars -> 10 tokens + message_event("D" * 40), # 40 chars -> 10 tokens + ] + # Total 40 tokens, reduce by 15 means keep suffix after removing 1 event (10 + # tokens). 
Actually need to remove 2 events (20 tokens) to exceed 15 token + # reduction + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, 15) + assert suffix_length == 2 # Keep last 2 events + + def test_large_reduction(self, mock_llm: LLM): + """Test with large token reduction that removes most events.""" + events = [ + message_event("A" * 20), # 20 chars -> 5 tokens + message_event("B" * 20), # 20 chars -> 5 tokens + message_event("C" * 20), # 20 chars -> 5 tokens + message_event("D" * 20), # 20 chars -> 5 tokens + ] + # Total 20 tokens, reduce by 18 tokens means remove 4 events (20 tokens) + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, 18) + assert suffix_length == 0 # Keep nothing + + def test_exact_reduction(self, mock_llm: LLM): + """Test with exact token reduction matching some events.""" + events = [ + message_event("A" * 40), # 40 chars -> 10 tokens + message_event("B" * 40), # 40 chars -> 10 tokens + message_event("C" * 40), # 40 chars -> 10 tokens + ] + # Total 30 tokens, reduce by exactly 10 tokens + # Need to remove 2 events (20 tokens) to exceed 10 token reduction + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, 10) + assert suffix_length == 1 # Keep last 1 event + + def test_impossible_reduction(self, mock_llm: LLM): + """Test when requested reduction exceeds total tokens.""" + events = [ + message_event("Hi"), # 2 chars -> 0 tokens + message_event("Bye"), # 3 chars -> 0 tokens + ] + # Total ~0 tokens, but asking to reduce by 100 + suffix_length = get_suffix_length_for_token_reduction(events, mock_llm, 100) + assert suffix_length == 0 # Can't keep anything + + def test_consistency_with_prefix_function(self, mock_llm: LLM): + """Test that suffix calculation is consistent with prefix calculation.""" + events = [ + message_event("A" * 40), # 40 chars -> 10 tokens + message_event("B" * 40), # 40 chars -> 10 tokens + message_event("C" * 40), # 40 chars -> 10 tokens + message_event("D" * 
40), # 40 chars -> 10 tokens + ] + token_reduction = 25 + + suffix_length = get_suffix_length_for_token_reduction( + events, mock_llm, token_reduction + ) + prefix_length = get_shortest_prefix_above_token_count( + events, mock_llm, token_reduction + ) + + # Suffix + prefix should equal total length + assert suffix_length + prefix_length == len(events)