diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 287507258f74..82834607adee 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -448,8 +448,8 @@ async fn process_message_streaming( // Create a user message let user_message = GooseMessage::user().with_text(content.clone()); - // Get existing messages from session and add the new user message - let mut messages = { + // Messages will be auto-compacted in agent.reply() if needed + let messages = { let mut session_msgs = session_messages.lock().await; session_msgs.push(user_message.clone()); session_msgs.clone() @@ -618,7 +618,10 @@ async fn process_message_streaming( // TODO: Implement proper UI for context handling let (summarized_messages, _) = agent.summarize_context(&messages).await?; - messages = summarized_messages; + { + let mut session_msgs = session_messages.lock().await; + *session_msgs = summarized_messages; + } } _ => { // Handle other message types as needed @@ -626,6 +629,30 @@ async fn process_message_streaming( } } } + Ok(AgentEvent::HistoryReplaced(new_messages)) => { + // Replace the session's message history with the compacted messages + { + let mut session_msgs = session_messages.lock().await; + *session_msgs = new_messages; + } + + // Persist the updated messages to the JSONL file + let current_messages = { + let session_msgs = session_messages.lock().await; + session_msgs.clone() + }; + + if let Err(e) = session::persist_messages( + &session_file, + ¤t_messages, + None, // No provider needed for persisting + working_dir.clone(), + ) + .await + { + error!("Failed to persist compacted messages: {}", e); + } + } Ok(AgentEvent::McpNotification(_notification)) => { // Handle MCP notifications if needed // For now, we'll just log them diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index e9cf80a233b3..94b655d22e17 100644 --- a/crates/goose-cli/src/session/mod.rs +++ 
b/crates/goose-cli/src/session/mod.rs @@ -846,6 +846,7 @@ impl Session { } async fn process_agent_response(&mut self, interactive: bool) -> Result<()> { + // Messages will be auto-compacted in agent.reply() if needed let cancel_token = CancellationToken::new(); let cancel_token_clone = cancel_token.clone(); @@ -1140,6 +1141,25 @@ impl Session { _ => (), } } + Some(Ok(AgentEvent::HistoryReplaced(new_messages))) => { + // Replace the session's message history with the compacted messages + self.messages = new_messages; + + // Persist the updated messages to the session file + if let Some(session_file) = &self.session_file { + let provider = self.agent.provider().await.ok(); + let working_dir = std::env::current_dir().ok(); + if let Err(e) = session::persist_messages_with_schedule_id( + session_file, + &self.messages, + provider, + self.scheduled_job_id.clone(), + working_dir, + ).await { + eprintln!("Failed to persist compacted messages: {}", e); + } + } + } Some(Ok(AgentEvent::ModelChange { model, mode })) => { // Log model change if in debug mode if self.debug { diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index d3ac7208e0a0..60cec0fadde1 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -159,8 +159,15 @@ async fn reply_handler( retry_config: None, }; + // Messages will be auto-compacted in agent.reply() if needed + let messages_to_process = messages.clone(); + let mut stream = match agent - .reply(&messages, Some(session_config), Some(task_cancel.clone())) + .reply( + &messages_to_process, + Some(session_config), + Some(task_cancel.clone()), + ) .await { Ok(stream) => stream, @@ -215,6 +222,12 @@ async fn reply_handler( break; } } + Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { + // Replace the message history with the compacted messages + all_messages = new_messages; + // Note: We don't send this as a stream event since it's an internal operation + 
// The client will see the compaction notification message that was sent before this event + } Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await { tracing::error!("Error sending model change through channel: {}", e); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 5a9e7e4a7d39..c106f43f6be1 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -33,6 +33,7 @@ use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; +use crate::context_mgmt::auto_compact; use crate::message::{push_message, Message, ToolRequest}; use crate::permission::permission_judge::{check_tool_permissions, PermissionCheckResult}; use crate::permission::PermissionConfirmation; @@ -102,6 +103,7 @@ pub enum AgentEvent { Message(Message), McpNotification((String, ServerNotification)), ModelChange { model: String, mode: String }, + HistoryReplaced(Vec), } impl Default for Agent { @@ -746,6 +748,36 @@ impl Agent { } } + /// Handle auto-compaction logic and return compacted messages if needed + async fn handle_auto_compaction( + &self, + messages: &[Message], + ) -> Result, String)>> { + let compact_result = auto_compact::check_and_compact_messages(self, messages, None).await?; + + if compact_result.compacted { + let compacted_messages = compact_result.messages; + + // Create compaction notification message + let compaction_msg = if let (Some(before), Some(after)) = + (compact_result.tokens_before, compact_result.tokens_after) + { + format!( + "Auto-compacted context: {} → {} tokens ({:.0}% reduction)\n\n", + before, + after, + (1.0 - (after as f64 / before as f64)) * 100.0 + ) + } else { + "Auto-compacted context to reduce token usage\n\n".to_string() + }; 
+ + return Ok(Some((compacted_messages, compaction_msg))); + } + + Ok(None) + } + #[instrument(skip(self, unfixed_messages, session), fields(user_message))] pub async fn reply( &self, @@ -753,9 +785,44 @@ impl Agent { session: Option, cancel_token: Option, ) -> Result>> { - let context = self - .prepare_reply_context(unfixed_messages, &session) - .await?; + // Handle auto-compaction before processing + let (messages, compaction_msg) = match self.handle_auto_compaction(unfixed_messages).await? + { + Some((compacted_messages, msg)) => (compacted_messages, Some(msg)), + None => { + let context = self + .prepare_reply_context(unfixed_messages, &session) + .await?; + (context.messages, None) + } + }; + + // If we compacted, yield the compaction message and history replacement event + if let Some(compaction_msg) = compaction_msg { + return Ok(Box::pin(async_stream::try_stream! { + yield AgentEvent::Message(Message::assistant().with_text(compaction_msg)); + yield AgentEvent::HistoryReplaced(messages.clone()); + + // Continue with normal reply processing using compacted messages + let mut reply_stream = self.reply_internal(&messages, session, cancel_token).await?; + while let Some(event) = reply_stream.next().await { + yield event?; + } + })); + } + + // No compaction needed, proceed with normal processing + self.reply_internal(&messages, session, cancel_token).await + } + + /// Main reply method that handles the actual agent processing + async fn reply_internal( + &self, + messages: &[Message], + session: Option, + cancel_token: Option, + ) -> Result>> { + let context = self.prepare_reply_context(messages, &session).await?; let ReplyContext { mut messages, mut tools, @@ -765,7 +832,6 @@ impl Agent { initial_messages, config, } = context; - let reply_span = tracing::Span::current(); self.reset_retry_attempts().await; diff --git a/crates/goose/src/context_mgmt/auto_compact.rs b/crates/goose/src/context_mgmt/auto_compact.rs new file mode 100644 index 
000000000000..45268f4e8db3 --- /dev/null +++ b/crates/goose/src/context_mgmt/auto_compact.rs @@ -0,0 +1,533 @@ +use crate::{ + agents::Agent, + config::Config, + context_mgmt::{estimate_target_context_limit, get_messages_token_counts_async}, + message::Message, + token_counter::create_async_token_counter, +}; +use anyhow::Result; +use tracing::{debug, info}; + +/// Result of auto-compaction check +#[derive(Debug)] +pub struct AutoCompactResult { + /// Whether compaction was performed + pub compacted: bool, + /// The messages after potential compaction + pub messages: Vec, + /// Token count before compaction (if compaction occurred) + pub tokens_before: Option, + /// Token count after compaction (if compaction occurred) + pub tokens_after: Option, +} + +/// Result of checking if compaction is needed +#[derive(Debug)] +pub struct CompactionCheckResult { + /// Whether compaction is needed + pub needs_compaction: bool, + /// Current token count + pub current_tokens: usize, + /// Context limit being used + pub context_limit: usize, + /// Current usage ratio (0.0 to 1.0) + pub usage_ratio: f64, + /// Remaining tokens before compaction threshold + pub remaining_tokens: usize, + /// Percentage until compaction threshold (0.0 to 100.0) + pub percentage_until_compaction: f64, +} + +/// Check if messages need compaction without performing the compaction +/// +/// This function analyzes the current token usage and returns detailed information +/// about whether compaction is needed and how close we are to the threshold. 
+/// +/// # Arguments +/// * `agent` - The agent to use for context management +/// * `messages` - The current message history +/// * `threshold_override` - Optional threshold override (defaults to GOOSE_AUTO_COMPACT_THRESHOLD config) +/// +/// # Returns +/// * `CompactionCheckResult` containing detailed information about compaction needs +pub async fn check_compaction_needed( + agent: &Agent, + messages: &[Message], + threshold_override: Option<f64>, +) -> Result<CompactionCheckResult> { + // Get threshold from config or use override + let config = Config::global(); + let threshold = threshold_override.unwrap_or_else(|| { + config + .get_param::<f64>("GOOSE_AUTO_COMPACT_THRESHOLD") + .unwrap_or(0.3) // Default to 30% + }); + + // Get provider and token counter + let provider = agent.provider().await?; + let token_counter = create_async_token_counter() + .await + .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?; + + // Calculate current token usage + let token_counts = get_messages_token_counts_async(&token_counter, messages); + let current_tokens: usize = token_counts.iter().sum(); + let context_limit = estimate_target_context_limit(provider); + + // Calculate usage ratio + let usage_ratio = current_tokens as f64 / context_limit as f64; + + // Calculate threshold token count and remaining tokens + let threshold_tokens = (context_limit as f64 * threshold) as usize; + let remaining_tokens = threshold_tokens.saturating_sub(current_tokens); + + // Calculate percentage until compaction (how much more we can use before hitting threshold) + let percentage_until_compaction = if usage_ratio < threshold { + (threshold - usage_ratio) * 100.0 + } else { + 0.0 + }; + + // Check if compaction is needed (disabled if threshold is invalid) + let needs_compaction = if threshold <= 0.0 || threshold >= 1.0 { + false + } else { + usage_ratio > threshold + }; + + debug!( + "Compaction check: {} / {} tokens ({:.1}%), threshold: {:.1}%, needs compaction: {}", + current_tokens, + context_limit, + 
usage_ratio * 100.0, + threshold * 100.0, + needs_compaction + ); + + Ok(CompactionCheckResult { + needs_compaction, + current_tokens, + context_limit, + usage_ratio, + remaining_tokens, + percentage_until_compaction, + }) +} + +/// Perform compaction on messages +/// +/// This function performs the actual compaction using the agent's summarization +/// capabilities. It assumes compaction is needed and should be called after +/// `check_compaction_needed` confirms it's necessary. +/// +/// # Arguments +/// * `agent` - The agent to use for context management +/// * `messages` - The current message history to compact +/// +/// # Returns +/// * Tuple of (compacted_messages, tokens_before, tokens_after) +pub async fn perform_compaction( + agent: &Agent, + messages: &[Message], +) -> Result<(Vec<Message>, usize, usize)> { + // Get token counter to measure before/after + let token_counter = create_async_token_counter() + .await + .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?; + + // Calculate tokens before compaction + let token_counts_before = get_messages_token_counts_async(&token_counter, messages); + let tokens_before: usize = token_counts_before.iter().sum(); + + info!("Performing compaction on {} tokens", tokens_before); + + // Perform compaction + let (compacted_messages, compacted_token_counts) = agent.summarize_context(messages).await?; + let tokens_after: usize = compacted_token_counts.iter().sum(); + + info!( + "Compaction complete: {} tokens -> {} tokens ({:.1}% reduction)", + tokens_before, + tokens_after, + (1.0 - (tokens_after as f64 / tokens_before as f64)) * 100.0 + ); + + Ok((compacted_messages, tokens_before, tokens_after)) +} + +/// Check if messages need compaction and compact them if necessary +/// +/// This is a convenience wrapper function that combines checking and compaction. +/// If the most recent message is a user message, it will be preserved by removing it +/// before compaction and adding it back afterwards. 
+/// +/// # Arguments +/// * `agent` - The agent to use for context management +/// * `messages` - The current message history +/// * `threshold_override` - Optional threshold override (defaults to GOOSE_AUTO_COMPACT_THRESHOLD config) +/// +/// # Returns +/// * `AutoCompactResult` containing the potentially compacted messages and metadata +pub async fn check_and_compact_messages( + agent: &Agent, + messages: &[Message], + threshold_override: Option, +) -> Result { + // First check if compaction is needed + let check_result = check_compaction_needed(agent, messages, threshold_override).await?; + + // If no compaction is needed, return early + if !check_result.needs_compaction { + debug!( + "No compaction needed (usage: {:.1}% <= {:.1}% threshold)", + check_result.usage_ratio * 100.0, + check_result.percentage_until_compaction + ); + return Ok(AutoCompactResult { + compacted: false, + messages: messages.to_vec(), + tokens_before: None, + tokens_after: None, + }); + } + + info!( + "Auto-compacting messages (usage: {:.1}%)", + check_result.usage_ratio * 100.0 + ); + + // Check if the most recent message is a user message + let (messages_to_compact, preserved_user_message) = if let Some(last_message) = messages.last() + { + if matches!(last_message.role, rmcp::model::Role::User) { + // Remove the last user message before auto-compaction + (&messages[..messages.len() - 1], Some(last_message.clone())) + } else { + (messages, None) + } + } else { + (messages, None) + }; + + // Perform the compaction on messages excluding the preserved user message + let (mut compacted_messages, tokens_before, tokens_after) = + perform_compaction(agent, messages_to_compact).await?; + + // Add back the preserved user message if it exists + if let Some(user_message) = preserved_user_message { + compacted_messages.push(user_message); + } + + Ok(AutoCompactResult { + compacted: true, + messages: compacted_messages, + tokens_before: Some(tokens_before), + tokens_after: Some(tokens_after), + }) 
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + agents::Agent, + message::{Message, MessageContent}, + model::ModelConfig, + providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}, + providers::errors::ProviderError, + }; + use chrono::Utc; + use rmcp::model::{AnnotateAble, RawTextContent, Role, Tool}; + use std::sync::Arc; + + #[derive(Clone)] + struct MockProvider { + model_config: ModelConfig, + } + + #[async_trait::async_trait] + impl Provider for MockProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + // Return a short summary message + Ok(( + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text( + RawTextContent { + text: "Summary of conversation".to_string(), + } + .no_annotation(), + )], + ), + ProviderUsage::new("mock".to_string(), Usage::default()), + )) + } + } + + fn create_test_message(text: &str) -> Message { + Message::new( + Role::User, + Utc::now().timestamp(), + vec![MessageContent::text(text.to_string())], + ) + } + + #[tokio::test] + async fn test_check_compaction_needed() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(100_000.into()), + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + // Create small messages that won't trigger compaction + let messages = vec![create_test_message("Hello"), create_test_message("World")]; + + let result = check_compaction_needed(&agent, &messages, Some(0.3)) + .await + .unwrap(); + + assert!(!result.needs_compaction); + assert!(result.current_tokens > 0); + assert!(result.context_limit > 0); + assert!(result.usage_ratio < 0.3); + assert!(result.remaining_tokens > 0); + 
assert!(result.percentage_until_compaction > 0.0); + } + + #[tokio::test] + async fn test_check_compaction_needed_disabled() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(100_000.into()), + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + let messages = vec![create_test_message("Hello")]; + + // Test with threshold 0 (disabled) + let result = check_compaction_needed(&agent, &messages, Some(0.0)) + .await + .unwrap(); + + assert!(!result.needs_compaction); + + // Test with threshold 1.0 (disabled) + let result = check_compaction_needed(&agent, &messages, Some(1.0)) + .await + .unwrap(); + + assert!(!result.needs_compaction); + } + + #[tokio::test] + async fn test_perform_compaction() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(50_000.into()), + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + // Create some messages to compact + let messages = vec![ + create_test_message("First message"), + create_test_message("Second message"), + create_test_message("Third message"), + ]; + + let (compacted_messages, tokens_before, tokens_after) = + perform_compaction(&agent, &messages).await.unwrap(); + + assert!(tokens_before > 0); + assert!(tokens_after > 0); + // Note: The mock provider returns a fixed summary, which might not always be smaller + // In real usage, compaction should reduce tokens, but for testing we just verify it works + assert!(!compacted_messages.is_empty()); + } + + #[tokio::test] + async fn test_auto_compact_disabled() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(10_000.into()), + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + let messages = vec![create_test_message("Hello"), 
create_test_message("World")]; + + // Test with threshold 0 (disabled) + let result = check_and_compact_messages(&agent, &messages, Some(0.0)) + .await + .unwrap(); + + assert!(!result.compacted); + assert_eq!(result.messages.len(), messages.len()); + assert!(result.tokens_before.is_none()); + assert!(result.tokens_after.is_none()); + + // Test with threshold 1.0 (disabled) + let result = check_and_compact_messages(&agent, &messages, Some(1.0)) + .await + .unwrap(); + + assert!(!result.compacted); + } + + #[tokio::test] + async fn test_auto_compact_below_threshold() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(100_000.into()), // Increased to ensure overhead doesn't dominate + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + // Create small messages that won't trigger compaction + let messages = vec![create_test_message("Hello"), create_test_message("World")]; + + let result = check_and_compact_messages(&agent, &messages, Some(0.3)) + .await + .unwrap(); + + assert!(!result.compacted); + assert_eq!(result.messages.len(), messages.len()); + } + + #[tokio::test] + async fn test_auto_compact_above_threshold() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(50_000.into()), // Realistic context limit that won't underflow + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + // Create messages that will exceed 30% of the context limit + // With 50k context limit, after overhead we have ~27k usable tokens + // 30% of that is ~8.1k tokens, so we need messages that exceed that + let mut messages = Vec::new(); + + // Create longer messages with more content to reach the threshold + for i in 0..200 { + messages.push(create_test_message(&format!( + "This is message number {} with significantly more content to increase token count. 
\ + We need to ensure that our total token usage exceeds 30% of the available context \ + limit after accounting for system prompt and tools overhead. This message contains \ + multiple sentences to increase the token count substantially.", + i + ))); + } + + let result = check_and_compact_messages(&agent, &messages, Some(0.3)) + .await + .unwrap(); + + assert!(result.compacted); + assert!(result.tokens_before.is_some()); + assert!(result.tokens_after.is_some()); + + // Should have fewer tokens after compaction + if let (Some(before), Some(after)) = (result.tokens_before, result.tokens_after) { + assert!( + after < before, + "Token count should decrease after compaction" + ); + } + + // Should have fewer messages (summarized) + assert!(result.messages.len() <= messages.len()); + } + + #[tokio::test] + async fn test_auto_compact_respects_config() { + let mock_provider = Arc::new(MockProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(30_000.into()), // Smaller context limit to make threshold easier to hit + }); + + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + + // Create enough messages to trigger compaction with low threshold + let mut messages = Vec::new(); + // With 30k context limit, after overhead we have ~27k usable tokens + // 10% of 27k = 2.7k tokens, so we need messages that exceed that + for i in 0..200 { + messages.push(create_test_message(&format!( + "Message {} with enough content to ensure we exceed 10% of the context limit. \ + Adding more content to increase token count substantially. This message contains \ + multiple sentences to increase the token count. 
We need to ensure that our total \ + token usage exceeds 10% of the available context limit after accounting for \ + system prompt and tools overhead.", + i + ))); + } + + // Set config value + let config = Config::global(); + config + .set_param("GOOSE_AUTO_COMPACT_THRESHOLD", serde_json::Value::from(0.1)) + .unwrap(); + + // Should use config value when no override provided + let result = check_and_compact_messages(&agent, &messages, None) + .await + .unwrap(); + + // Debug info if not compacted + if !result.compacted { + let provider = agent.provider().await.unwrap(); + let token_counter = create_async_token_counter().await.unwrap(); + let token_counts = get_messages_token_counts_async(&token_counter, &messages); + let total_tokens: usize = token_counts.iter().sum(); + let context_limit = estimate_target_context_limit(provider); + let usage_ratio = total_tokens as f64 / context_limit as f64; + + eprintln!( + "Config test not compacted - tokens: {} / {} ({:.1}%)", + total_tokens, + context_limit, + usage_ratio * 100.0 + ); + } + + // With such a low threshold (10%), it should compact + assert!(result.compacted); + + // Clean up config + config + .set_param("GOOSE_AUTO_COMPACT_THRESHOLD", serde_json::Value::from(0.3)) + .unwrap(); + } +} diff --git a/crates/goose/src/context_mgmt/common.rs b/crates/goose/src/context_mgmt/common.rs index 3f9054361b95..3883386a2e44 100644 --- a/crates/goose/src/context_mgmt/common.rs +++ b/crates/goose/src/context_mgmt/common.rs @@ -19,8 +19,14 @@ pub fn estimate_target_context_limit(provider: Arc) -> usize { // Our token count is an estimate since model providers often don't provide the tokenizer (eg. 
Claude) let target_limit = (model_context_limit as f32 * ESTIMATE_FACTOR) as usize; - // subtract out overhead for system prompt and tools - target_limit - (SYSTEM_PROMPT_TOKEN_OVERHEAD + TOOLS_TOKEN_OVERHEAD) + // subtract out overhead for system prompt and tools, but ensure we don't go negative + let overhead = SYSTEM_PROMPT_TOKEN_OVERHEAD + TOOLS_TOKEN_OVERHEAD; + if target_limit > overhead { + target_limit - overhead + } else { + // If overhead is larger than target limit, return a minimal usable limit + std::cmp::max(target_limit / 2, 1000) + } } pub fn get_messages_token_counts(token_counter: &TokenCounter, messages: &[Message]) -> Vec { diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs index 838e27fece54..00d11d6b871b 100644 --- a/crates/goose/src/context_mgmt/mod.rs +++ b/crates/goose/src/context_mgmt/mod.rs @@ -1,3 +1,4 @@ +pub mod auto_compact; mod common; pub mod summarize; pub mod truncate; diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs index ef798a2d2df3..c7a92fa2954a 100644 --- a/crates/goose/src/context_mgmt/summarize.rs +++ b/crates/goose/src/context_mgmt/summarize.rs @@ -1,14 +1,21 @@ use super::common::{get_messages_token_counts, get_messages_token_counts_async}; -use crate::message::{Message, MessageContent}; +use crate::message::Message; +use crate::prompt_template::render_global_file; use crate::providers::base::Provider; use crate::token_counter::{AsyncTokenCounter, TokenCounter}; use anyhow::Result; use rmcp::model::Role; +use serde::Serialize; use std::sync::Arc; // Constants for the summarization prompt and a follow-up user message. const SUMMARY_PROMPT: &str = "You are good at summarizing conversations"; +#[derive(Serialize)] +struct SummarizeContext { + messages: String, +} + /// Summarize the combined messages from the accumulated summary and the current chunk. 
/// /// This method builds the summarization request, sends it to the provider, and returns the summarized response. @@ -43,62 +50,54 @@ async fn summarize_combined_messages( Ok(vec![response]) } -/// Preprocesses the messages to handle edge cases involving tool responses. -/// -/// This function separates messages into two groups: -/// 1. Messages to be summarized (`preprocessed_messages`) -/// 2. Messages to be temporarily removed (`removed_messages`), which include: -/// - The last tool response message. -/// - The corresponding tool request message that immediately precedes the last tool response message (if present). -/// -/// The function only considers the last tool response message and its pair for removal. -fn preprocess_messages(messages: &[Message]) -> (Vec, Vec) { - let mut preprocessed_messages = messages.to_owned(); - let mut removed_messages = Vec::new(); - - if let Some((last_index, last_message)) = messages.iter().enumerate().rev().find(|(_, m)| { - m.content - .iter() - .any(|c| matches!(c, MessageContent::ToolResponse(_))) - }) { - // Check for the corresponding tool request message - if last_index > 0 { - if let Some(previous_message) = messages.get(last_index - 1) { - if previous_message - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_))) - { - // Add the tool request message to removed_messages - removed_messages.push(previous_message.clone()); - } - } - } - // Add the last tool response message to removed_messages - removed_messages.push(last_message.clone()); +// Summarization steps: +// Using a single tailored prompt, summarize the entire conversation history. 
+pub async fn summarize_messages_oneshot( + provider: Arc, + messages: &[Message], + token_counter: &TokenCounter, + _context_limit: usize, +) -> Result<(Vec, Vec), anyhow::Error> { + if messages.is_empty() { + // If no messages to summarize, return empty + return Ok((vec![], vec![])); + } - // Calculate the correct start index for removal - let start_index = last_index + 1 - removed_messages.len(); + // Format all messages as a single string for the summarization prompt + let messages_text = messages + .iter() + .map(|msg| format!("{:?}", msg)) + .collect::>() + .join("\n\n"); - // Remove the tool response and its paired tool request from preprocessed_messages - preprocessed_messages.drain(start_index..=last_index); - } + let context = SummarizeContext { + messages: messages_text, + }; - (preprocessed_messages, removed_messages) -} + // Render the one-shot summarization prompt + let system_prompt = render_global_file("summarize_oneshot.md", &context)?; -/// Reinserts removed messages into the summarized output. -/// -/// This function appends messages that were temporarily removed during preprocessing -/// back into the summarized message list. This ensures that important context, -/// such as tool responses, is not lost. -fn reintegrate_removed_messages( - summarized_messages: &[Message], - removed_messages: &[Message], -) -> Vec { - let mut final_messages = summarized_messages.to_owned(); - final_messages.extend_from_slice(removed_messages); - final_messages + // Create a simple user message requesting summarization + let user_message = Message::user() + .with_text("Please summarize the conversation history provided in the system prompt."); + let summarization_request = vec![user_message]; + + // Send the request to the provider and fetch the response. + let mut response = provider + .complete(&system_prompt, &summarization_request, &[]) + .await? + .0; + + // Set role to user as it will be used in following conversation as user content. 
+ response.role = Role::User; + + // Return just the summary without any tool response preservation + let final_summary = vec![response]; + + Ok(( + final_summary.clone(), + get_messages_token_counts(token_counter, &final_summary), + )) } // Summarization steps: @@ -107,7 +106,7 @@ fn reintegrate_removed_messages( // a. Combine it with the previous summary (or leave blank for the first iteration). // b. Summarize the combined text, focusing on extracting only the information we need. // 3. Generate a final summary using a tailored prompt. -pub async fn summarize_messages( +pub async fn summarize_messages_chunked( provider: Arc, messages: &[Message], token_counter: &TokenCounter, @@ -117,17 +116,14 @@ pub async fn summarize_messages( let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT); let mut accumulated_summary = Vec::new(); - // Preprocess messages to handle tool response edge case. - let (preprocessed_messages, removed_messages) = preprocess_messages(messages); - // Get token counts for each message. - let token_counts = get_messages_token_counts(token_counter, &preprocessed_messages); + let token_counts = get_messages_token_counts(token_counter, messages); // Tokenize and break messages into chunks. let mut current_chunk: Vec = Vec::new(); let mut current_chunk_tokens = 0; - for (message, message_tokens) in preprocessed_messages.iter().zip(token_counts.iter()) { + for (message, message_tokens) in messages.iter().zip(token_counts.iter()) { if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens { // Summarize the current chunk with the accumulated summary. accumulated_summary = @@ -150,15 +146,61 @@ pub async fn summarize_messages( summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk).await?; } - // Add back removed messages. 
- let final_summary = reintegrate_removed_messages(&accumulated_summary, &removed_messages); - + // Return just the summary without any tool response preservation Ok(( - final_summary.clone(), - get_messages_token_counts(token_counter, &final_summary), + accumulated_summary.clone(), + get_messages_token_counts(token_counter, &accumulated_summary), )) } +/// Main summarization function that chooses the best algorithm based on context size. +/// +/// This function will: +/// 1. First try the one-shot summarization if there's enough context window available +/// 2. Fall back to the chunked approach if the one-shot fails or if context is too limited +/// 3. Choose the algorithm based on absolute token requirements rather than percentages +pub async fn summarize_messages( + provider: Arc, + messages: &[Message], + token_counter: &TokenCounter, + context_limit: usize, +) -> Result<(Vec, Vec), anyhow::Error> { + // Calculate total tokens in messages + let total_tokens: usize = get_messages_token_counts(token_counter, messages) + .iter() + .sum(); + + // Calculate absolute token requirements (future-proof for large context models) + let system_prompt_overhead = 1000; // Conservative estimate for the summarization prompt + let response_overhead = 4000; // Generous buffer for response generation + let safety_buffer = 1000; // Small safety margin for tokenization variations + let total_required = total_tokens + system_prompt_overhead + response_overhead + safety_buffer; + + // Use one-shot if we have enough absolute space (no percentage-based limits) + if total_required <= context_limit { + match summarize_messages_oneshot( + Arc::clone(&provider), + messages, + token_counter, + context_limit, + ) + .await + { + Ok(result) => return Ok(result), + Err(e) => { + // Log the error but continue to fallback + tracing::warn!( + "One-shot summarization failed, falling back to chunked approach: {}", + e + ); + } + } + } + + // Fall back to the chunked approach + 
summarize_messages_chunked(provider, messages, token_counter, context_limit).await +} + /// Async version using AsyncTokenCounter for better performance pub async fn summarize_messages_async( provider: Arc, @@ -170,17 +212,14 @@ pub async fn summarize_messages_async( let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT); let mut accumulated_summary = Vec::new(); - // Preprocess messages to handle tool response edge case. - let (preprocessed_messages, removed_messages) = preprocess_messages(messages); - // Get token counts for each message. - let token_counts = get_messages_token_counts_async(token_counter, &preprocessed_messages); + let token_counts = get_messages_token_counts_async(token_counter, messages); // Tokenize and break messages into chunks. let mut current_chunk: Vec = Vec::new(); let mut current_chunk_tokens = 0; - for (message, message_tokens) in preprocessed_messages.iter().zip(token_counts.iter()) { + for (message, message_tokens) in messages.iter().zip(token_counts.iter()) { if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens { // Summarize the current chunk with the accumulated summary. accumulated_summary = @@ -203,12 +242,10 @@ pub async fn summarize_messages_async( summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk).await?; } - // Add back removed messages. 
- let final_summary = reintegrate_removed_messages(&accumulated_summary, &removed_messages); - + // Return just the summary without any tool response preservation Ok(( - final_summary.clone(), - get_messages_token_counts_async(token_counter, &final_summary), + accumulated_summary.clone(), + get_messages_token_counts_async(token_counter, &accumulated_summary), )) } @@ -220,11 +257,9 @@ mod tests { use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::errors::ProviderError; use chrono::Utc; - use mcp_core::ToolCall; use rmcp::model::Role; use rmcp::model::Tool; - use rmcp::model::{AnnotateAble, Content, RawTextContent}; - use serde_json::json; + use rmcp::model::{AnnotateAble, RawTextContent}; use std::sync::Arc; #[derive(Clone)] @@ -265,8 +300,7 @@ mod tests { } fn create_mock_provider() -> Result> { - let mock_model_config = - ModelConfig::new_or_fail("test-model").with_context_limit(200_000.into()); + let mock_model_config = ModelConfig::new("test-model")?.with_context_limit(200_000.into()); Ok(Arc::new(MockProvider { model_config: mock_model_config, @@ -285,30 +319,11 @@ mod tests { Message::new(role, 0, vec![MessageContent::text(text.to_string())]) } - fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message { - Message::new( - Role::Assistant, - 0, - vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], - ) - } - - fn set_up_tool_response_message(id: &str, tool_response: Vec) -> Message { - Message::new( - Role::User, - 0, - vec![MessageContent::tool_response( - id.to_string(), - Ok(tool_response), - )], - ) - } - #[tokio::test] async fn test_summarize_messages_single_chunk() { let provider = create_mock_provider().expect("failed to create mock provider"); let token_counter = TokenCounter::new(); - let context_limit = 100; // Set a high enough limit to avoid chunking. 
+ let context_limit = 10_000; // Higher limit to avoid underflow let messages = create_test_messages(); let result = summarize_messages( @@ -344,7 +359,7 @@ mod tests { async fn test_summarize_messages_multiple_chunks() { let provider = create_mock_provider().expect("failed to create mock provider"); let token_counter = TokenCounter::new(); - let context_limit = 30; + let context_limit = 10_000; // Higher limit to avoid underflow let messages = create_test_messages(); let result = summarize_messages( @@ -380,7 +395,7 @@ mod tests { async fn test_summarize_messages_empty_input() { let provider = create_mock_provider().expect("failed to create mock provider"); let token_counter = TokenCounter::new(); - let context_limit = 100; + let context_limit = 10_000; // Higher limit to avoid underflow let messages: Vec = Vec::new(); let result = summarize_messages( @@ -406,73 +421,264 @@ mod tests { } #[tokio::test] - async fn test_preprocess_messages_without_tool_response() { - let messages = create_test_messages(); - let (preprocessed_messages, removed_messages) = preprocess_messages(&messages); + async fn test_summarize_messages_uses_oneshot_for_small_context() { + let provider = create_mock_provider().expect("failed to create mock provider"); + let token_counter = TokenCounter::new(); + let context_limit = 100_000; // Large context limit + let messages = create_test_messages(); // Small message set + let result = summarize_messages( + Arc::clone(&provider), + &messages, + &token_counter, + context_limit, + ) + .await; + + assert!(result.is_ok(), "The function should return Ok."); + let (summarized_messages, _) = result.unwrap(); + + // Should use one-shot and return a single summarized message assert_eq!( - preprocessed_messages.len(), - 3, - "Only the user message should remain after preprocessing." + summarized_messages.len(), + 1, + "Should use one-shot summarization for small context." 
); + } + + #[tokio::test] + async fn test_summarize_messages_uses_chunked_for_large_context() { + let provider = create_mock_provider().expect("failed to create mock provider"); + let token_counter = TokenCounter::new(); + let context_limit = 10_000; // Higher limit to avoid underflow + let messages = create_test_messages(); + + let result = summarize_messages( + Arc::clone(&provider), + &messages, + &token_counter, + context_limit, + ) + .await; + + assert!(result.is_ok(), "The function should return Ok."); + let (summarized_messages, _) = result.unwrap(); + + // Should fall back to chunked approach assert_eq!( - removed_messages.len(), - 0, - "The tool request and tool response messages should be removed." + summarized_messages.len(), + 1, + "Should use chunked summarization for large context." ); } + // Mock provider that fails on one-shot but succeeds on chunked + #[derive(Clone)] + struct FailingOneshotProvider { + model_config: ModelConfig, + call_count: Arc>, + } + + #[async_trait::async_trait] + impl Provider for FailingOneshotProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let mut count = self.call_count.lock().unwrap(); + *count += 1; + + // Fail if this looks like a one-shot request + if system.contains("reasoning in `` tags") { + return Err(ProviderError::RateLimitExceeded( + "Simulated one-shot failure".to_string(), + )); + } + + // Succeed for chunked requests (uses the old SUMMARY_PROMPT) + Ok(( + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text( + RawTextContent { + text: "Chunked summary".to_string(), + } + .no_annotation(), + )], + ), + ProviderUsage::new("mock".to_string(), Usage::default()), + )) + } + } + #[tokio::test] - async fn 
test_preprocess_messages_with_tool_response() { - let arguments = json!({ - "param1": "value1" + async fn test_summarize_messages_fallback_on_oneshot_failure() { + let call_count = Arc::new(std::sync::Mutex::new(0)); + let provider = Arc::new(FailingOneshotProvider { + model_config: ModelConfig::new("test-model") + .unwrap() + .with_context_limit(200_000.into()), + call_count: Arc::clone(&call_count), }); - let messages = vec![ - set_up_text_message("Message 1", Role::User), - set_up_tool_request_message("id", ToolCall::new("tool_name", json!(arguments))), - set_up_tool_response_message("id", vec![Content::text("tool done")]), - ]; + let token_counter = TokenCounter::new(); + let context_limit = 100_000; // Large enough to try one-shot first + let messages = create_test_messages(); + + let result = summarize_messages(provider, &messages, &token_counter, context_limit).await; + + assert!( + result.is_ok(), + "The function should return Ok after fallback." + ); + let (summarized_messages, _) = result.unwrap(); + + // Should have fallen back to chunked approach + assert_eq!( + summarized_messages.len(), + 1, + "Should successfully fall back to chunked approach." 
+ ); + + // Verify the content comes from the chunked approach + if let MessageContent::Text(text_content) = &summarized_messages[0].content[0] { + assert_eq!(text_content.text, "Chunked summary"); + } else { + panic!("Expected text content"); + } + + // Should have made multiple calls (one-shot attempt + chunked calls) + let final_count = *call_count.lock().unwrap(); + assert!( + final_count > 1, + "Should have made multiple provider calls during fallback" + ); + } + + #[tokio::test] + async fn test_summarize_messages_oneshot_direct_call() { + let provider = create_mock_provider().expect("failed to create mock provider"); + let token_counter = TokenCounter::new(); + let context_limit = 100_000; + let messages = create_test_messages(); + + let result = summarize_messages_oneshot( + Arc::clone(&provider), + &messages, + &token_counter, + context_limit, + ) + .await; - let (preprocessed_messages, removed_messages) = preprocess_messages(&messages); + assert!( + result.is_ok(), + "One-shot summarization should work directly." + ); + let (summarized_messages, token_counts) = result.unwrap(); assert_eq!( - preprocessed_messages.len(), + summarized_messages.len(), 1, - "Only the user message should remain after preprocessing." + "One-shot should return a single summary message." + ); + assert_eq!( + summarized_messages[0].role, + Role::User, + "Summary should be from user role for context." ); assert_eq!( - removed_messages.len(), - 2, - "The tool request and tool response messages should be removed." + token_counts.len(), + 1, + "Should have token count for the summary." 
); } #[tokio::test] - async fn test_reintegrate_removed_messages() { - let summarized_messages = vec![Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::Text( - RawTextContent { - text: "Summary".to_string(), - } - .no_annotation(), - )], - )]; - let arguments = json!({ - "param1": "value1" - }); - let removed_messages = vec![ - set_up_tool_request_message("id", ToolCall::new("tool_name", json!(arguments))), - set_up_tool_response_message("id", vec![Content::text("tool done")]), - ]; + async fn test_summarize_messages_chunked_direct_call() { + let provider = create_mock_provider().expect("failed to create mock provider"); + let token_counter = TokenCounter::new(); + let context_limit = 10_000; // Higher limit to avoid underflow + let messages = create_test_messages(); + + let result = summarize_messages_chunked( + Arc::clone(&provider), + &messages, + &token_counter, + context_limit, + ) + .await; - let final_messages = reintegrate_removed_messages(&summarized_messages, &removed_messages); + assert!( + result.is_ok(), + "Chunked summarization should work directly." + ); + let (summarized_messages, token_counts) = result.unwrap(); assert_eq!( - final_messages.len(), - 3, - "The final message list should include the summary and removed messages." + summarized_messages.len(), + 1, + "Chunked should return a single final summary." + ); + assert_eq!( + summarized_messages[0].role, + Role::User, + "Summary should be from user role for context." + ); + assert_eq!( + token_counts.len(), + 1, + "Should have token count for the summary." 
+ ); + } + + #[tokio::test] + async fn test_absolute_token_threshold_calculation() { + let provider = create_mock_provider().expect("failed to create mock provider"); + let token_counter = TokenCounter::new(); + + // Test with a context limit where absolute token calculation matters + let context_limit = 10_000; + let system_prompt_overhead = 1000; + let response_overhead = 4000; + let safety_buffer = 1000; + let max_message_tokens = + context_limit - system_prompt_overhead - response_overhead - safety_buffer; // 4000 tokens + + // Create messages that are just under the absolute threshold + let mut large_messages = Vec::new(); + let base_message = set_up_text_message("x".repeat(50).as_str(), Role::User); + + // Add enough messages to approach but not exceed the absolute threshold + let message_tokens = token_counter.count_tokens(&format!("{:?}", base_message)); + let num_messages = (max_message_tokens / message_tokens).saturating_sub(1); + + for i in 0..num_messages { + large_messages.push(set_up_text_message(&format!("Message {}", i), Role::User)); + } + + let result = summarize_messages( + Arc::clone(&provider), + &large_messages, + &token_counter, + context_limit, + ) + .await; + + assert!( + result.is_ok(), + "Should handle absolute threshold calculation correctly." 
); + let (summarized_messages, _) = result.unwrap(); + assert_eq!(summarized_messages.len(), 1, "Should produce a summary."); } } diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index dad3f74c2bda..6799c01ad9c0 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -329,6 +329,7 @@ mod tests { #[test] fn test_valid_configurations() { + // Test with environment variables set with_var("GOOSE_CONTEXT_LIMIT", Some("50000"), || { with_var("GOOSE_TEMPERATURE", Some("0.7"), || { with_var("GOOSE_TOOLSHIM", Some("true"), || { diff --git a/crates/goose/src/prompts/summarize_oneshot.md b/crates/goose/src/prompts/summarize_oneshot.md new file mode 100644 index 000000000000..8e621f2058aa --- /dev/null +++ b/crates/goose/src/prompts/summarize_oneshot.md @@ -0,0 +1,26 @@ +## Summary Task +Generate detailed summary of conversation to date. +Include user requests, your responses, and all technical content. + +Wrap reasoning in `` tags: +- Review conversation chronologically +- For each part, log: + - User goals and requests + - Your method and solution + - Key decisions and designs + - File names, code, signatures, errors, fixes +- Highlight user feedback and revisions +- Confirm completeness and accuracy + +### Summary Must Include the Following Sections: +1. **User Intent** – All goals and requests +2. **Technical Concepts** – All discussed tools, methods +3. **Files + Code** – Viewed/edited files, full code, change justifications +4. **Errors + Fixes** – Bugs, resolutions, user-driven changes +5. **Problem Solving** – Issues solved or in progress +6. **User Messages** – All user messages, exclude tool output +7. **Pending Tasks** – All unresolved user requests +8. **Current Work** – Active work at summary request time: filenames, code, alignment to latest instruction +9. 
**Next Step** – *Include only if* directly continues user instruction + +> No new ideas unless user confirmed diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 833fd4547aa4..93a419beeeac 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -518,6 +518,7 @@ impl Provider for ClaudeCodeProvider { mod tests { use super::ModelConfig; use super::*; + use temp_env::with_var; #[test] fn test_claude_code_model_config() { diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 0c6a304b09ee..5541b2a34f4c 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1236,7 +1236,9 @@ async fn run_scheduled_job_internal( Ok(AgentEvent::ModelChange { .. }) => { // Model change events are informational, just continue } - + Ok(AgentEvent::HistoryReplaced(_)) => { + // Handle history replacement events if needed + } Err(e) => { tracing::error!( "[Job {}] Error receiving message from agent: {}", diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 8ef7854576b8..cc77eb5f15b7 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -143,7 +143,9 @@ async fn run_truncate_test( Ok(AgentEvent::ModelChange { .. }) => { // Model change events are informational, just continue } - + Ok(AgentEvent::HistoryReplaced(_)) => { + // Handle history replacement events if needed + } Err(e) => { println!("Error: {:?}", e); return Err(e); @@ -1043,6 +1045,7 @@ mod max_turns_tests { } Ok(AgentEvent::McpNotification(_)) => {} Ok(AgentEvent::ModelChange { .. }) => {} + Ok(AgentEvent::HistoryReplaced(_)) => {} Err(e) => { return Err(e); }