Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions crates/goose/src/conversation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,17 @@ pub fn fix_conversation(conversation: Conversation) -> (Conversation, Vec<String

fn fix_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let (messages_1, empty_removed) = remove_empty_messages(messages);
let (messages_2, tool_calling_fixed) = fix_tool_calling(messages_1);
let (messages_3, messages_merged) = merge_consecutive_messages(messages_2);
// Merge consecutive messages BEFORE fixing tool calling, so tool request/response
// pairs aren't disrupted by merging after the fact
let (messages_2, messages_merged) = merge_consecutive_messages(messages_1);
let (messages_3, tool_calling_fixed) = fix_tool_calling(messages_2);
let (messages_4, lead_trail_fixed) = fix_lead_trail(messages_3);
let (messages_5, populated_if_empty) = populate_if_empty(messages_4);

let mut issues = Vec::new();
issues.extend(empty_removed);
issues.extend(tool_calling_fixed);
issues.extend(messages_merged);
issues.extend(tool_calling_fixed);
issues.extend(lead_trail_fixed);
issues.extend(populated_if_empty);

Expand All @@ -184,7 +186,30 @@ fn remove_empty_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>)

fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
let mut pending_tool_requests: HashSet<String> = HashSet::new();

// First pass: collect ALL tool request IDs from assistant messages (including agent_invisible)
let mut all_tool_requests: HashSet<String> = HashSet::new();
for message in &messages {
if message.role == Role::Assistant {
for content in &message.content {
if let MessageContent::ToolRequest(req) = content {
all_tool_requests.insert(req.id.clone());
}
}
}
}

// Collect tool request IDs that are in agent_visible assistant messages
let mut agent_visible_tool_requests: HashSet<String> = HashSet::new();
for message in &messages {
if message.role == Role::Assistant && message.metadata.agent_visible {
for content in &message.content {
if let MessageContent::ToolRequest(req) = content {
agent_visible_tool_requests.insert(req.id.clone());
}
}
}
}

for message in &mut messages {
let mut content_to_remove = Vec::new();
Expand Down Expand Up @@ -212,12 +237,18 @@ fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
issues.push("Removed thinking content from user message".to_string());
}
MessageContent::ToolResponse(resp) => {
if pending_tool_requests.contains(&resp.id) {
pending_tool_requests.remove(&resp.id);
} else {
// Remove tool responses that:
// 1. Don't have any matching tool request at all
// 2. Have a matching request but it's in an agent_invisible message
if !all_tool_requests.contains(&resp.id) {
content_to_remove.push(idx);
issues
.push(format!("Removed orphaned tool response '{}'", resp.id));
issues.push(format!("Removed orphaned tool response '{}'", resp.id));
} else if !agent_visible_tool_requests.contains(&resp.id) {
content_to_remove.push(idx);
issues.push(format!(
"Removed tool response '{}' whose request is in agent-invisible message",
resp.id
));
}
}
_ => {}
Expand All @@ -241,8 +272,8 @@ fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
req.id
));
}
MessageContent::ToolRequest(req) => {
pending_tool_requests.insert(req.id.clone());
MessageContent::ToolRequest(_) => {
// Don't remove tool requests from assistant messages
}
_ => {}
}
Expand All @@ -255,14 +286,27 @@ fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
}
}

// Second pass: Check for orphaned tool requests in agent_visible assistant messages
// (tool requests that don't have corresponding responses)
let mut tool_responses_seen: HashSet<String> = HashSet::new();
for message in &messages {
if message.role == Role::User {
for content in &message.content {
if let MessageContent::ToolResponse(resp) = content {
tool_responses_seen.insert(resp.id.clone());
}
}
}
}

for message in &mut messages {
if message.role == Role::Assistant {
if message.role == Role::Assistant && message.metadata.agent_visible {
let mut content_to_remove = Vec::new();
for (idx, content) in message.content.iter().enumerate() {
if let MessageContent::ToolRequest(req) = content {
if pending_tool_requests.contains(&req.id) {
if !tool_responses_seen.contains(&req.id) {
content_to_remove.push(idx);
issues.push(format!("Removed orphaned tool request '{}'", req.id));
issues.push(format!("Removed orphaned tool request '{}' from agent-visible assistant message", req.id));
}
}
}
Expand All @@ -271,6 +315,7 @@ fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
}
}
}

let (messages, empty_removed) = remove_empty_messages(messages);
issues.extend(empty_removed);
(messages, issues)
Expand Down Expand Up @@ -532,8 +577,8 @@ mod tests {

assert_eq!(fixed.len(), 5);
assert_eq!(issues.len(), 2);
assert!(issues[0].contains("Removed orphaned tool request"));
assert!(issues[1].contains("Merged consecutive assistant messages"));
assert!(issues[0].contains("Merged consecutive assistant messages"));
assert!(issues[1].contains("Removed orphaned tool request"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did this switch?

}

#[test]
Expand Down