Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 99 additions & 21 deletions crates/goose/src/context_mgmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::conversation::message::{Message, MessageContent};
use crate::conversation::Conversation;
use crate::prompt_template::render_global_file;
use crate::providers::base::{Provider, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::{agents::Agent, config::Config, token_counter::create_token_counter};
use anyhow::Result;
use rmcp::model::Role;
Expand Down Expand Up @@ -219,6 +220,56 @@ pub async fn check_if_compaction_needed(
Ok(needs_compaction)
}

/// Returns `messages` with a share of the tool-response messages dropped,
/// starting from the middle of the tool-response sequence and working outwards.
///
/// A message counts as a tool response when any of its content items is
/// `MessageContent::ToolResponse`. Messages without one are always kept, and
/// the relative order of the kept messages is unchanged.
///
/// # Arguments
/// * `messages` - the messages to filter.
/// * `remove_percent` - percentage of tool-response messages to drop. `0`
///   returns the input unchanged; values above `100` are treated as `100`.
///   Any non-zero value drops at least one tool-response message when one
///   exists.
///
/// # Returns
/// A new vector borrowing the surviving messages.
fn filter_tool_responses<'a>(messages: &[&'a Message], remove_percent: u32) -> Vec<&'a Message> {
    fn has_tool_response(msg: &Message) -> bool {
        msg.content
            .iter()
            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
    }

    if remove_percent == 0 {
        return messages.to_vec();
    }

    let tool_indices: Vec<usize> = messages
        .iter()
        .enumerate()
        .filter(|(_, msg)| has_tool_response(msg))
        .map(|(i, _)| i)
        .collect();

    if tool_indices.is_empty() {
        return messages.to_vec();
    }

    // Clamp so a percentage above 100 cannot ask for more removals than exist;
    // the floor of one guarantees every non-zero request makes progress.
    let percent = remove_percent.min(100) as usize;
    let num_to_remove = (tool_indices.len() * percent / 100).max(1);

    // Middle out: alternate between the entry just before the midpoint and the
    // entry at/after it. When one side runs dry, keep drawing from the other so
    // exactly `num_to_remove` entries are selected (the left half is one shorter
    // whenever the tool-response count is odd).
    let middle = tool_indices.len() / 2;
    let (before, after) = tool_indices.split_at(middle);
    let mut before = before.iter().rev();
    let mut after = after.iter();

    // Boolean mask instead of a list + `contains`, so the final pass is linear.
    let mut drop_mask = vec![false; messages.len()];
    for step in 0..num_to_remove {
        let pick = if step % 2 == 0 {
            before.next().or_else(|| after.next())
        } else {
            after.next().or_else(|| before.next())
        };
        if let Some(&idx) = pick {
            drop_mask[idx] = true;
        }
    }

    messages
        .iter()
        .zip(drop_mask)
        .filter(|(_, dropped)| !*dropped)
        .map(|(msg, _)| *msg)
        .collect()
}

async fn do_compact(
provider: Arc<dyn Provider>,
messages: &[Message],
Expand All @@ -228,34 +279,61 @@ async fn do_compact(
.filter(|msg| msg.is_agent_visible())
.collect();

let messages_text = agent_visible_messages
.iter()
.map(|&msg| format_message_for_compacting(msg))
.collect::<Vec<_>>()
.join("\n");
// Try progressively removing more tool response messages from the middle to reduce context length
let removal_percentages = vec![0, 10, 20, 50, 100];

let context = SummarizeContext {
messages: messages_text,
};
for (attempt, &remove_percent) in removal_percentages.iter().enumerate() {
let filtered_messages = filter_tool_responses(&agent_visible_messages, remove_percent);

let messages_text = filtered_messages
.iter()
.map(|&msg| format_message_for_compacting(msg))
.collect::<Vec<_>>()
.join("\n");

let system_prompt = render_global_file("summarize_oneshot.md", &context)?;
let context = SummarizeContext {
messages: messages_text,
};

let user_message = Message::user()
.with_text("Please summarize the conversation history provided in the system prompt.");
let summarization_request = vec![user_message];
let system_prompt = render_global_file("summarize_oneshot.md", &context)?;

let (mut response, mut provider_usage) = provider
.complete_fast(&system_prompt, &summarization_request, &[])
.await?;
let user_message = Message::user()
.with_text("Please summarize the conversation history provided in the system prompt.");
let summarization_request = vec![user_message];

response.role = Role::User;
match provider
.complete_fast(&system_prompt, &summarization_request, &[])
.await
{
Ok((mut response, mut provider_usage)) => {
response.role = Role::User;

provider_usage
.ensure_tokens(&system_prompt, &summarization_request, &response, &[])
.await
.map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;
provider_usage
.ensure_tokens(&system_prompt, &summarization_request, &response, &[])
.await
.map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;

Ok(Some((response, provider_usage)))
return Ok(Some((response, provider_usage)));
}
Err(e) => {
if matches!(e, ProviderError::ContextLengthExceeded(_)) {
if attempt < removal_percentages.len() - 1 {
continue;
} else {
return Err(anyhow::anyhow!(
"Failed to compact messages: context length still exceeded after {} attempts with maximum removal",
removal_percentages.len()
));
}
}
return Err(e.into());
}
}
}

Err(anyhow::anyhow!(
"Unexpected: exhausted all attempts without returning"
))
}

fn format_message_for_compacting(msg: &Message) -> String {
Expand Down
Loading