diff --git a/Cargo.lock b/Cargo.lock index 26b559153d7b..d84047ebbdd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3600,6 +3600,7 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tokio", + "tokio-stream", "tracing", "tracing-appender", "tracing-subscriber", diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 8bf405e7dfb3..cfa787a75042 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -12,6 +12,7 @@ mcp-core = { path = "../mcp-core" } mcp-server = { path = "../mcp-server" } anyhow = "1.0.94" tokio = { version = "1", features = ["full"] } +tokio-stream = { version = "0.1", features = ["io-util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index df95f76dfee7..0c4976e4183e 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -5,19 +5,22 @@ use anyhow::Result; use base64::Engine; use etcetera::{choose_app_strategy, AppStrategy}; use indoc::formatdoc; +use regex::Regex; use serde_json::{json, Value}; use std::{ collections::HashMap, future::Future, - io::Cursor, + io::{Cursor, Write as _}, path::{Path, PathBuf}, pin::Pin, }; +use tempfile::NamedTempFile; use tokio::{ io::{AsyncBufReadExt, BufReader}, process::Command, sync::mpsc, }; +use tokio_stream::{wrappers::SplitStream, StreamExt as _}; use url::Url; use include_dir::{include_dir, Dir}; @@ -48,6 +51,9 @@ use xcap::{Monitor, Window}; use ignore::gitignore::{Gitignore, GitignoreBuilder}; +// Avoid spamming the LLM by default +const MAX_SHELL_INLINE_CHAR_COUNT: usize = 10_000; + // Embeds the prompts directory to the build static PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/developer/prompts"); @@ -94,12 +100,44 @@ pub fn load_prompt_files() -> HashMap { prompts } +struct TmpfileOutput { + tmpf: NamedTempFile, + path: String, + /// Writing the output 
but also automatically returning lines + /// which match a regex. + highlight: Option, +} + +impl TmpfileOutput { + fn new(highlight: Option) -> std::io::Result { + let tmpf = tempfile::NamedTempFile::with_prefix("goose-shell-output-")?; + let path = tmpf + .path() + .to_str() + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to get path for temporary file", + ) + })? + .to_string(); + let r = Self { + tmpf, + path, + highlight, + }; + Ok(r) + } +} + pub struct DeveloperRouter { tools: Vec, prompts: Arc>, instructions: String, file_history: Arc>>>, ignore_patterns: Arc, + // The last (standard) output from a shell command + previous_stdout: Arc>>, } impl Default for DeveloperRouter { @@ -139,14 +177,39 @@ impl DeveloperRouter { _ => indoc! {r#" Execute a command in the shell. - This will return the output and error concatenated into a single string, as + By default this will return the output and error concatenated into a single string, as you would see from running on the command line. There will also be an indication of if the command succeeded or failed. - Avoid commands that produce a large amount of output, and consider piping those outputs to files. + ## Handling output + + **Important** you may find that some commands output significant amounts of text (such as + a software build command like `make`) and most of the time you will only care about the error + messages. For these commands it's recommended to set `output_tmpfile`. If set, + the output will be saved to a temporary file; you can then read it if needed - and for example + you *should* look at the saved output if the tool errors. + The temporary output file will be automatically removed after invoking the next shell command that uses + `output_tmpfile`. + + ## Selectively acquiring output text with a regular expression + + The `output_tmpfile` parameter can be the empty string; if so, then the first text result will always + be empty. 
+ + However if a non-empty string is provided it should be a regular expression that can be used to + automatically return matching content from the output. For example if you know a command may output text like + `warning: ...` then you could provide the regular expression "^(warning|error):.*" to have + that content automatically returned in the primary text, even if the command succeeds. Note the regular + expression must match the entire line (hence the trailing .*). There is not a trailing newline in the text to be matched. + Note if this filtered content itself is too large, it will be truncated (based on the start). + + ## Suggestions on long lived processes + If you need to run a long lived command, background it - e.g. `uvicorn main:app &` so that this tool does not run indefinitely. + ## Other tips + **Important**: Each shell command runs in its own process. Things like directory changes or sourcing files do not persist between tool calls. So you may need to repeat them each time by stringing together commands, e.g. 
`cd example && ls` or `source env/bin/activate && pip install numpy` @@ -165,7 +228,8 @@ impl DeveloperRouter { "type": "object", "required": ["command"], "properties": { - "command": {"type": "string"} + "command": {"type": "string"}, + "output_tmpfile": {"type": "string"} } }), None, @@ -443,6 +507,7 @@ impl DeveloperRouter { instructions, file_history: Arc::new(Mutex::new(HashMap::new())), ignore_patterns: Arc::new(ignore_patterns), + previous_stdout: Arc::new(Default::default()), } } @@ -483,6 +548,24 @@ impl DeveloperRouter { "The command string is required".to_string(), ))?; + let output_tmpfile = params.get("output_tmpfile").and_then(|v| v.as_str()); + let output_tmpfile = match output_tmpfile { + Some(s) => { + let re = if s.is_empty() { + None + } else { + let re = regex::Regex::new(s).map_err(|e| { + ToolError::InvalidParameters(format!("Invalid output_tmpfile regexp: {e}")) + })?; + Some(re) + }; + Some(TmpfileOutput::new(re).map_err(|e| { + ToolError::InternalError(format!("Failed to create tmpfile: {e}")) + })?) + } + None => None, + }; + // Check if command might access ignored files and return early if it does let cmd_parts: Vec<&str> = command.split_whitespace().collect(); for arg in &cmd_parts[1..] { @@ -519,77 +602,64 @@ impl DeveloperRouter { .spawn() .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - let stdout = child.stdout.take().unwrap(); - let stderr = child.stderr.take().unwrap(); - - let mut stdout_reader = BufReader::new(stdout); - let mut stderr_reader = BufReader::new(stderr); + let stdout = BufReader::new(child.stdout.take().unwrap()); + let stderr = BufReader::new(child.stderr.take().unwrap()); let output_task = tokio::spawn(async move { let mut combined_output = String::new(); - - let mut stdout_buf = Vec::new(); - let mut stderr_buf = Vec::new(); - - let mut stdout_done = false; - let mut stderr_done = false; - - loop { - tokio::select! { - n = stdout_reader.read_until(b'\n', &mut stdout_buf), if !stdout_done => { - if n? 
== 0 { - stdout_done = true; - } else { - let line = String::from_utf8_lossy(&stdout_buf); - - notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "shell", - "stream": "stdout", - "output": line.to_string(), - } - })), - })).ok(); - - combined_output.push_str(&line); - stdout_buf.clear(); - } - } - - n = stderr_reader.read_until(b'\n', &mut stderr_buf), if !stderr_done => { - if n? == 0 { - stderr_done = true; - } else { - let line = String::from_utf8_lossy(&stderr_buf); - - notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "shell", - "stream": "stderr", - "output": line.to_string(), - } - })), - })).ok(); - - combined_output.push_str(&line); - stderr_buf.clear(); - } - } - - else => break, + let mut output_tmpfile = output_tmpfile; + + // We have the individual two streams above, now merge them into one unified stream of + // an enum. 
ref https://blog.yoshuawuyts.com/futures-concurrency-3 + let stdout = SplitStream::new(stdout.split(b'\n')).map(|v| ("stdout", v)); + let stderr = SplitStream::new(stderr.split(b'\n')).map(|v| ("stderr", v)); + let mut merged = stdout.merge(stderr); + + while let Some((key, line)) = merged.next().await { + let mut line = line?; + // Re-add this as clients expect it + line.push(b'\n'); + // Here we always convert to UTF-8 so agents don't have to deal with corrupted output + let line = String::from_utf8_lossy(&line); + + // This keeps the user updated with progress of commands + notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "shell", + "stream": key, + "output": line, + } + })), + })) + .ok(); + + // Enable automatic redirection + if output_tmpfile.is_none() + && (combined_output.len() + line.len()) > MAX_SHELL_INLINE_CHAR_COUNT + { + let mut new_tmpf = TmpfileOutput::new(None)?; + new_tmpf.tmpf.write_all(combined_output.as_bytes())?; + output_tmpfile = Some(new_tmpf); + combined_output.clear(); } - if stdout_done && stderr_done { - break; + let Some(redirect) = output_tmpfile.as_mut() else { + combined_output.push_str(&line); + continue; + }; + if let Some(re) = redirect.highlight.as_ref() { + if re.is_match(line.trim_end_matches('\n')) { + combined_output.push_str(&line); + } } + redirect.tmpf.write_all(line.as_bytes())?; } - Ok::<_, std::io::Error>(combined_output) + + Ok::<_, std::io::Error>((combined_output, output_tmpfile)) }); // Wait for the command to complete and get output @@ -598,29 +668,23 @@ impl DeveloperRouter { .await .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - let output_str = match output_task.await { + let (output_str, output_tmpfile) = match output_task.await { Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?, Err(e) => return 
Err(ToolError::ExecutionError(e.to_string())), }; - // Check the character count of the output - const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB - let char_count = output_str.chars().count(); - if char_count > MAX_CHAR_COUNT { - return Err(ToolError::ExecutionError(format!( - "Shell output from command '{}' has too many characters ({}). Maximum character count is {}.", - command, - char_count, - MAX_CHAR_COUNT - ))); + let mut r = vec![Content::text(output_str.clone()).with_audience(vec![Role::Assistant])]; + if let Some(output_tmpfile) = output_tmpfile { + r.push(Content::text(output_tmpfile.path).with_audience(vec![Role::Assistant])); + let mut previous = self.previous_stdout.lock().unwrap(); + *previous = Some(output_tmpfile.tmpf); } - - Ok(vec![ - Content::text(output_str.clone()).with_audience(vec![Role::Assistant]), + r.push( Content::text(output_str) .with_audience(vec![Role::User]) .with_priority(0.0), - ]) + ); + Ok(r) } async fn text_editor(&self, params: Value) -> Result, ToolError> { @@ -1209,6 +1273,7 @@ impl Clone for DeveloperRouter { instructions: self.instructions.clone(), file_history: Arc::clone(&self.file_history), ignore_patterns: Arc::clone(&self.ignore_patterns), + previous_stdout: Arc::clone(&self.previous_stdout), } } } @@ -1218,7 +1283,8 @@ mod tests { use super::*; use serde_json::json; use serial_test::serial; - use std::fs; + use std::io::Write; + use std::{fs, io::BufWriter}; use tempfile::TempDir; use tokio::sync::OnceCell; @@ -1631,6 +1697,7 @@ mod tests { instructions: String::new(), file_history: Arc::new(Mutex::new(HashMap::new())), ignore_patterns: Arc::new(ignore_patterns), + previous_stdout: Default::default(), }; // Test basic file matching @@ -1681,6 +1748,7 @@ mod tests { instructions: String::new(), file_history: Arc::new(Mutex::new(HashMap::new())), ignore_patterns: Arc::new(ignore_patterns), + previous_stdout: Default::default(), }; // Try to write to an ignored file @@ -1740,6 +1808,7 @@ mod tests { 
instructions: String::new(), file_history: Arc::new(Mutex::new(HashMap::new())), ignore_patterns: Arc::new(ignore_patterns), + previous_stdout: Default::default(), }; // Create an ignored file @@ -1973,4 +2042,398 @@ mod tests { temp_dir.close().unwrap(); } + + #[tokio::test] + #[serial] + async fn test_shell_output_tmpfile() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + // Call the shell tool with redirection enabled + let command = "echo 'first line' && echo 'second line'"; + let result = router + .call_tool( + "shell", + json!({ + "command": command, + "output_tmpfile": "" + }), + dummy_sender(), + ) + .await + .unwrap(); + + // Verify the output indicates redirection + let output = result + .iter() + .filter(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .collect::>(); + assert_eq!(output.len(), 2); + assert_eq!(output[0].as_text().unwrap().len(), 0); + let path = output[1].as_text().unwrap(); + // Read the content from the temporary file + let file_content = tokio::fs::read_to_string(path).await.unwrap(); + + // The temporary file holds only the command's output (not its own path), one echoed line per line + let mut lines = file_content.lines(); + assert_eq!(lines.next().unwrap().trim(), "first line"); + assert_eq!(lines.next().unwrap().trim(), "second line"); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_automatic_redirection() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + // Create a command that outputs more than MAX_SHELL_INLINE_CHAR_COUNT + let token = "aaaaaaaaaaaaaaaaaaaa"; + let large = MAX_SHELL_INLINE_CHAR_COUNT.div_ceil(token.len()) + 1; + let large_output_command = format!("for x in $(seq {large}); do echo -n {token}; done"); + + let result = router + .call_tool( + "shell", + json!({ + "command": 
large_output_command + }), + dummy_sender(), + ) + .await + .unwrap(); + + let output = result + .iter() + .filter(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .collect::>(); + assert_eq!(output.len(), 2); + assert_eq!(output[0].as_text().unwrap().len(), 0); + // The second content should be the path to the temporary file + let temp_file_path_content = output[1].as_text().unwrap(); + assert!(!temp_file_path_content.is_empty()); + let temp_file_path = Path::new(temp_file_path_content); + let meta = temp_file_path.symlink_metadata().unwrap(); + assert!(meta.is_file()); + let file_content = tokio::fs::read_to_string(temp_file_path).await.unwrap(); + let size = file_content.trim().chars().fold(0usize, |acc, v| { + if v == 'a' { + acc + 1 + } else { + panic!("Unexpected char in output: {v:?}") + } + }); + assert!(size > MAX_SHELL_INLINE_CHAR_COUNT); + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_output_tmpfile_filtered() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + let matcher = r"^(error|warning):.*"; + + // Now testing something like a typical compiler output + let command = "for x in $(seq 100); do echo \"some output\"; done; echo error: oops; echo warning: blah; echo other text; echo error: another error; for x in $(seq 100); do echo \"more output\"; done"; + let result = router + .call_tool( + "shell", + json!({ + "command": command, + "output_tmpfile": matcher + }), + dummy_sender(), + ) + .await + .unwrap(); + + // Verify the output indicates redirection + let output = result + .iter() + .filter(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .collect::>(); + assert_eq!(output.len(), 2); + let output0 = output[0].as_text().unwrap(); + assert_eq!( + output0.lines().filter(|v| v.starts_with("error:")).count(), + 2 + ); + assert_eq!( + output0 + .lines() + 
.filter(|v| v.starts_with("warning:")) + .count(), + 1 + ); + let path = output[1].as_text().unwrap(); + let path = Path::new(path); + let meta = path.symlink_metadata().unwrap(); + assert!(meta.is_file()); + #[cfg(unix)] + { + use std::os::unix::fs::MetadataExt; + assert_eq!(meta.size(), 2458); + } + + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_stream_merge_ordering() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + let result = router + .call_tool( + "shell", + json!({ + "command": "for x in $(seq 5); do echo stdout${x}; echo stderr${x} >&2; sleep 0.1; done" + }), + dummy_sender(), + ) + .await + .unwrap(); + + // Check the output includes interleaved stdout/stderr content + let output = result + .iter() + .find(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .unwrap() + .as_text() + .unwrap(); + + let lines: Vec<&str> = output.lines().collect(); + assert!(lines.len() >= 10, "Should have at least 10 lines of output"); + + // Verify some stdout lines exist + assert!(lines.iter().any(|&line| line.contains("stdout"))); + // Verify some stderr lines exist + assert!(lines.iter().any(|&line| line.contains("stderr"))); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_non_utf8_output() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + // Create a script that outputs some non-UTF8 bytes + let command = r#" +# Output some invalid UTF-8 sequences +printf '\xFF\xFE\xFF\xFE\n' +echo "Valid UTF8 after invalid" +printf '\x80\x81\x82\n' >&2 +echo "More valid text" >&2"#; + + let result = router + .call_tool( + "shell", + json!({ + "command": command + }), + dummy_sender(), + ) + .await + .unwrap(); + + let output = result + .iter() + .find(|c| { + c.audience() + .is_some_and(|roles| 
roles.contains(&Role::Assistant)) + }) + .unwrap() + .as_text() + .unwrap(); + + // Check that we got the valid UTF8 parts + assert!(output.contains("Valid UTF8 after invalid")); + assert!(output.contains("More valid text")); + // The invalid sequences should be replaced with the UTF-8 replacement character + assert!(output.contains("�")); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_stream_merge_large_output() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + // Create a script that outputs lots of data to both streams + let command = r#" +# Output 1000 lines to stdout +for i in {1..1000}; do + echo "stdout line $i" +done + +# Output 1000 lines to stderr +for i in {1..1000}; do + echo "stderr line $i" >&2 +done"#; + + let result = router + .call_tool( + "shell", + json!({ + "command": command, + "output_tmpfile": "" + }), + dummy_sender(), + ) + .await + .unwrap(); + + // Since we're using output_tmpfile, we should get a path + let outputs: Vec<_> = result + .iter() + .filter(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .collect(); + + assert_eq!(outputs.len(), 2); + assert_eq!(outputs[0].as_text().unwrap().len(), 0); // Empty initial output + let tmpfile_path = outputs[1].as_text().unwrap(); + + // Read the tmpfile content + let content = std::fs::read_to_string(tmpfile_path).unwrap(); + + // Verify we got both stdout and stderr content + assert!(content.contains("stdout line 1")); + assert!(content.contains("stdout line 1000")); + assert!(content.contains("stderr line 1")); + assert!(content.contains("stderr line 1000")); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_stream_merge_zero_length() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + // Test handling of 
zero-length outputs and EOF conditions + let command = r#" +printf '' > /dev/stdout +printf '' > /dev/stderr +echo -n "no newline stdout" > /dev/stdout +echo -n "no newline stderr" > /dev/stderr"#; + + let result = router + .call_tool( + "shell", + json!({ + "command": command, + }), + dummy_sender(), + ) + .await + .unwrap(); + + let output = result + .iter() + .find(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .unwrap() + .as_text() + .unwrap(); + + // We should get both outputs even without newlines + assert!(output.contains("no newline stdout")); + assert!(output.contains("no newline stderr")); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + #[serial] + async fn test_shell_stream_merge_binary_output() { + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(&temp_dir).unwrap(); + + let router = get_router().await; + + // Create a small binary file for testing + let mut binary_file = std::fs::File::create("test.bin") + .map(BufWriter::new) + .unwrap(); + for i in 0u8..255u8 { + binary_file.write_all(&[i]).unwrap(); + } + binary_file.flush().unwrap(); + drop(binary_file); + + // Create a script that outputs binary data mixed with text + let command = r#" +cat test.bin +echo "Text after binary" +cat test.bin >&2 +echo "Text after binary stderr" >&2"#; + + let result = router + .call_tool( + "shell", + json!({ + "command": command, + }), + dummy_sender(), + ) + .await + .unwrap(); + + let output = result + .iter() + .find(|c| { + c.audience() + .is_some_and(|roles| roles.contains(&Role::Assistant)) + }) + .unwrap() + .as_text() + .unwrap(); + + // Verify we got the text portions + assert!(output.contains("Text after binary")); + // Verify binary data was handled (converted to replacement chars where needed) + assert!(output.contains("�")); + + temp_dir.close().unwrap(); + } } diff --git a/crates/mcp-core/src/handler.rs b/crates/mcp-core/src/handler.rs index 2a4c1a77ee2f..912775ce3e24 100644 --- 
a/crates/mcp-core/src/handler.rs +++ b/crates/mcp-core/src/handler.rs @@ -12,6 +12,8 @@ use utoipa::ToSchema; pub enum ToolError { #[error("Invalid parameters: {0}")] InvalidParameters(String), + #[error("Internal error: {0}")] + InternalError(String), #[error("Execution failed: {0}")] ExecutionError(String), #[error("Schema error: {0}")]