diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index b7148f8864a5..1805d59e4403 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -29,13 +29,68 @@ struct CliProcess { child: tokio::process::Child, stdin: tokio::process::ChildStdin, reader: BufReader, - #[allow(dead_code)] stderr_handle: tokio::task::JoinHandle, messages_sent: usize, } +impl CliProcess { + async fn send_and_read( + &mut self, + content_blocks: &[Value], + ) -> Result, ProviderError> { + let ndjson_line = build_stream_json_input(content_blocks); + self.stdin + .write_all(ndjson_line.as_bytes()) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write to stdin: {}", e)) + })?; + self.stdin.write_all(b"\n").await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write newline to stdin: {}", e)) + })?; + + let mut lines = Vec::new(); + let mut line = String::new(); + + loop { + line.clear(); + match self.reader.read_line(&mut line).await { + Ok(0) => { + return Err(ProviderError::RequestFailed( + "Claude CLI process terminated unexpectedly".to_string(), + )); + } + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + lines.push(trimmed.to_string()); + + if let Ok(parsed) = serde_json::from_str::(trimmed) { + match parsed.get("type").and_then(|t| t.as_str()) { + Some("result") => break, + Some("error") => break, + _ => {} + } + } + } + Err(e) => { + return Err(ProviderError::RequestFailed(format!( + "Failed to read output: {}", + e + ))); + } + } + } + + Ok(lines) + } +} + impl Drop for CliProcess { fn drop(&mut self) { + self.stderr_handle.abort(); let _ = self.child.start_kill(); } } @@ -52,7 +107,7 @@ pub struct ClaudeCodeProvider { #[serde(skip)] name: String, #[serde(skip)] - cli_process: tokio::sync::OnceCell>, + cli_process: tokio::sync::Mutex>, } impl ClaudeCodeProvider { @@ -65,7 +120,7 @@ impl ClaudeCodeProvider { command: resolved_command, model, name: CLAUDE_CODE_PROVIDER_NAME.to_string(), - cli_process: tokio::sync::OnceCell::new(), + cli_process: tokio::sync::Mutex::new(None), }) } @@ -252,6 +307,71 @@ impl ClaudeCodeProvider { Ok((response_message, usage)) } + fn spawn_process(&self, filtered_system: &str) -> Result { + let mut cmd = Command::new(&self.command); + configure_command_no_window(&mut cmd); + cmd.arg("--input-format") + .arg("stream-json") + .arg("--output-format") + .arg("stream-json") + .arg("--verbose") + .arg("--system-prompt") + .arg(filtered_system); + + if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { + cmd.arg("--model").arg(&self.model.model_name); + } + + Self::apply_permission_flags(&mut cmd)?; + + cmd.stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let mut child = cmd.spawn().map_err(|e| { + ProviderError::RequestFailed(format!( + "Failed to spawn Claude CLI command '{:?}': {}.", + self.command, e + )) + })?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdin".to_string()))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; + + // Drain stderr concurrently to prevent pipe buffer deadlock + let stderr = child.stderr.take(); + let stderr_handle = tokio::spawn(async move { + let mut output = String::new(); + if let Some(mut stderr) = stderr { + use tokio::io::AsyncReadExt; + let _ = stderr.read_to_string(&mut output).await; + } + output + }); + + Ok(CliProcess { + child, + stdin, + reader: BufReader::new(stdout), + stderr_handle, + messages_sent: 0, + }) + } + + fn is_process_alive(process: &mut CliProcess) -> bool { + match process.child.try_wait() { + Ok(None) => true, + Ok(Some(_)) => false, + Err(_) => false, + } + } + async fn execute_command( &self, system: &str, @@ -268,146 +388,59 @@ impl ClaudeCodeProvider { "Filtered system prompt length: {} chars", filtered_system.len() ); - println!("Filtered system prompt: {}", filtered_system); println!("================================"); } - // Spawn lazily on first call (OnceCell ensures exactly once) - let process_mutex = self - .cli_process - .get_or_try_init(|| async { - let mut cmd = Command::new(&self.command); - // NO -p flag — persistent mode - configure_command_no_window(&mut cmd); - cmd.arg("--input-format") - .arg("stream-json") - .arg("--output-format") - .arg("stream-json") - .arg("--verbose") - // System prompt is set once at process start. The provider - // instance is not reused across sessions with different prompts. - .arg("--system-prompt") - .arg(&filtered_system); - - // Only pass model parameter if it's in the known models list - if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { - cmd.arg("--model").arg(&self.model.model_name); - } + let mut guard = self.cli_process.lock().await; - // Add permission mode based on GOOSE_MODE setting - Self::apply_permission_flags(&mut cmd)?; - - cmd.stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()); - - let mut child = cmd.spawn().map_err(|e| { - ProviderError::RequestFailed(format!( - "Failed to spawn Claude CLI command '{:?}': {}.", - self.command, e - )) - })?; - - let stdin = child.stdin.take().ok_or_else(|| { - ProviderError::RequestFailed("Failed to capture stdin".to_string()) - })?; - let stdout = child.stdout.take().ok_or_else(|| { - ProviderError::RequestFailed("Failed to capture stdout".to_string()) - })?; - - // Drain stderr concurrently to prevent pipe buffer deadlock - let stderr = child.stderr.take(); - let stderr_handle = tokio::spawn(async move { - let mut output = String::new(); - if let Some(mut stderr) = stderr { - use tokio::io::AsyncReadExt; - let _ = stderr.read_to_string(&mut output).await; - } - output - }); - - Ok::<_, ProviderError>(tokio::sync::Mutex::new(CliProcess { - child, - stdin, - reader: BufReader::new(stdout), - stderr_handle, - messages_sent: 0, - })) - }) - .await?; + let needs_spawn = match guard.as_mut() { + None => true, + Some(p) => !Self::is_process_alive(p), + }; + if needs_spawn { + *guard = Some(self.spawn_process(&filtered_system)?); + } - let mut process = process_mutex.lock().await; + let new_messages = self.content_blocks_for(guard.as_ref().unwrap(), messages); + let process = guard.as_mut().unwrap(); + match process.send_and_read(&new_messages).await { + Ok(lines) => { + process.messages_sent = messages.len(); + tracing::debug!("Command executed successfully, got {} lines", lines.len()); + Ok(lines) + } + Err(e) if Self::is_recoverable(&e) => { + tracing::debug!("CLI process dead, respawning"); + let process = guard.insert(self.spawn_process(&filtered_system)?); + let new_messages = self.content_blocks_for(process, messages); + let lines = process.send_and_read(&new_messages).await?; + process.messages_sent = messages.len(); + tracing::debug!( + "Command executed successfully after respawn, got {} lines", + lines.len() + ); + Ok(lines) + } + Err(e) => Err(e), + } + } - // Build content from new messages only (skip already-sent ones). - // If messages is shorter than messages_sent, the caller started a fresh - // conversation on the same provider instance — send everything. + fn content_blocks_for(&self, process: &CliProcess, messages: &[Message]) -> Vec { let new_messages = if process.messages_sent > 0 && process.messages_sent < messages.len() { &messages[process.messages_sent..] } else { messages }; - let new_blocks = self.messages_to_content_blocks(new_messages); - - // Write NDJSON line to stdin - let ndjson_line = build_stream_json_input(&new_blocks); - process - .stdin - .write_all(ndjson_line.as_bytes()) - .await - .map_err(|e| { - ProviderError::RequestFailed(format!("Failed to write to stdin: {}", e)) - })?; - process.stdin.write_all(b"\n").await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to write newline to stdin: {}", e)) - })?; - - // Read lines until we see a "result" event - let mut lines = Vec::new(); - let mut line = String::new(); - - loop { - line.clear(); - match process.reader.read_line(&mut line).await { - Ok(0) => { - // EOF means the process died - return Err(ProviderError::RequestFailed( - "Claude CLI process terminated unexpectedly".to_string(), - )); - } - Ok(_) => { - let trimmed = line.trim(); - if trimmed.is_empty() { - continue; - } - lines.push(trimmed.to_string()); + self.messages_to_content_blocks(new_messages) + } - // Check if this is a result event (end of turn) - if let Ok(parsed) = serde_json::from_str::(trimmed) { - match parsed.get("type").and_then(|t| t.as_str()) { - Some("result") => break, - Some("error") => break, - _ => {} - } - } - } - Err(e) => { - return Err(ProviderError::RequestFailed(format!( - "Failed to read output: {}", - e - ))); - } + fn is_recoverable(err: &ProviderError) -> bool { + match err { + ProviderError::RequestFailed(msg) => { + msg.contains("Broken pipe") || msg.contains("terminated unexpectedly") } + _ => false, } - - // Update messages_sent for next turn - process.messages_sent = messages.len(); - - tracing::debug!("Command executed successfully, got {} lines", lines.len()); - for (i, line) in lines.iter().enumerate() { - tracing::debug!("Line {}: {}", i, line); - } - - Ok(lines) } /// Generate a simple session description without calling subprocess @@ -734,12 +767,30 @@ mod tests { ); } + #[test] + fn test_is_recoverable() { + assert!(ClaudeCodeProvider::is_recoverable( + &ProviderError::RequestFailed( + "Failed to write to stdin: Broken pipe (os error 32)".into() + ) + )); + assert!(ClaudeCodeProvider::is_recoverable( + &ProviderError::RequestFailed("Claude CLI process terminated unexpectedly".into()) + )); + assert!(!ClaudeCodeProvider::is_recoverable( + &ProviderError::RequestFailed("Failed to read output: connection reset".into()) + )); + assert!(!ClaudeCodeProvider::is_recoverable( + &ProviderError::Authentication("Broken pipe".into()) // wrong variant + )); + } + fn make_provider() -> ClaudeCodeProvider { ClaudeCodeProvider { command: PathBuf::from("claude"), model: ModelConfig::new("sonnet").unwrap(), name: "claude-code".to_string(), - cli_process: tokio::sync::OnceCell::new(), + cli_process: tokio::sync::Mutex::new(None), } } }