Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 181 additions & 130 deletions crates/goose/src/providers/claude_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,68 @@ struct CliProcess {
child: tokio::process::Child,
stdin: tokio::process::ChildStdin,
reader: BufReader<tokio::process::ChildStdout>,
#[allow(dead_code)]
stderr_handle: tokio::task::JoinHandle<String>,
messages_sent: usize,
}

impl CliProcess {
async fn send_and_read(
&mut self,
content_blocks: &[Value],
) -> Result<Vec<String>, ProviderError> {
let ndjson_line = build_stream_json_input(content_blocks);
self.stdin
.write_all(ndjson_line.as_bytes())
.await
.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to write to stdin: {}", e))
})?;
self.stdin.write_all(b"\n").await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to write newline to stdin: {}", e))
})?;

let mut lines = Vec::new();
let mut line = String::new();

loop {
line.clear();
match self.reader.read_line(&mut line).await {
Ok(0) => {
return Err(ProviderError::RequestFailed(
"Claude CLI process terminated unexpectedly".to_string(),
));
}
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
lines.push(trimmed.to_string());

if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
match parsed.get("type").and_then(|t| t.as_str()) {
Some("result") => break,
Some("error") => break,
_ => {}
}
}
}
Err(e) => {
return Err(ProviderError::RequestFailed(format!(
"Failed to read output: {}",
e
)));
}
}
}

Ok(lines)
}
}

impl Drop for CliProcess {
fn drop(&mut self) {
self.stderr_handle.abort();
let _ = self.child.start_kill();
}
}
Expand All @@ -52,7 +107,7 @@ pub struct ClaudeCodeProvider {
#[serde(skip)]
name: String,
#[serde(skip)]
cli_process: tokio::sync::OnceCell<tokio::sync::Mutex<CliProcess>>,
cli_process: tokio::sync::Mutex<Option<CliProcess>>,
}

impl ClaudeCodeProvider {
Expand All @@ -65,7 +120,7 @@ impl ClaudeCodeProvider {
command: resolved_command,
model,
name: CLAUDE_CODE_PROVIDER_NAME.to_string(),
cli_process: tokio::sync::OnceCell::new(),
cli_process: tokio::sync::Mutex::new(None),
})
}

Expand Down Expand Up @@ -252,6 +307,71 @@ impl ClaudeCodeProvider {
Ok((response_message, usage))
}

fn spawn_process(&self, filtered_system: &str) -> Result<CliProcess, ProviderError> {
let mut cmd = Command::new(&self.command);
configure_command_no_window(&mut cmd);
cmd.arg("--input-format")
.arg("stream-json")
.arg("--output-format")
.arg("stream-json")
.arg("--verbose")
.arg("--system-prompt")
.arg(filtered_system);

if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) {
cmd.arg("--model").arg(&self.model.model_name);
}

Self::apply_permission_flags(&mut cmd)?;

cmd.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());

let mut child = cmd.spawn().map_err(|e| {
ProviderError::RequestFailed(format!(
"Failed to spawn Claude CLI command '{:?}': {}.",
self.command, e
))
})?;

let stdin = child
.stdin
.take()
.ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdin".to_string()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?;

// Drain stderr concurrently to prevent pipe buffer deadlock
let stderr = child.stderr.take();
let stderr_handle = tokio::spawn(async move {
let mut output = String::new();
if let Some(mut stderr) = stderr {
use tokio::io::AsyncReadExt;
let _ = stderr.read_to_string(&mut output).await;
}
output
});

Ok(CliProcess {
child,
stdin,
reader: BufReader::new(stdout),
stderr_handle,
messages_sent: 0,
})
}

fn is_process_alive(process: &mut CliProcess) -> bool {
match process.child.try_wait() {
Ok(None) => true,
Ok(Some(_)) => false,
Err(_) => false,
}
}

async fn execute_command(
&self,
system: &str,
Expand All @@ -268,146 +388,59 @@ impl ClaudeCodeProvider {
"Filtered system prompt length: {} chars",
filtered_system.len()
);
println!("Filtered system prompt: {}", filtered_system);
println!("================================");
}

// Spawn lazily on first call (OnceCell ensures exactly once)
let process_mutex = self
.cli_process
.get_or_try_init(|| async {
let mut cmd = Command::new(&self.command);
// NO -p flag — persistent mode
configure_command_no_window(&mut cmd);
cmd.arg("--input-format")
.arg("stream-json")
.arg("--output-format")
.arg("stream-json")
.arg("--verbose")
// System prompt is set once at process start. The provider
// instance is not reused across sessions with different prompts.
.arg("--system-prompt")
.arg(&filtered_system);

// Only pass model parameter if it's in the known models list
if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) {
cmd.arg("--model").arg(&self.model.model_name);
}
let mut guard = self.cli_process.lock().await;

// Add permission mode based on GOOSE_MODE setting
Self::apply_permission_flags(&mut cmd)?;

cmd.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());

let mut child = cmd.spawn().map_err(|e| {
ProviderError::RequestFailed(format!(
"Failed to spawn Claude CLI command '{:?}': {}.",
self.command, e
))
})?;

let stdin = child.stdin.take().ok_or_else(|| {
ProviderError::RequestFailed("Failed to capture stdin".to_string())
})?;
let stdout = child.stdout.take().ok_or_else(|| {
ProviderError::RequestFailed("Failed to capture stdout".to_string())
})?;

// Drain stderr concurrently to prevent pipe buffer deadlock
let stderr = child.stderr.take();
let stderr_handle = tokio::spawn(async move {
let mut output = String::new();
if let Some(mut stderr) = stderr {
use tokio::io::AsyncReadExt;
let _ = stderr.read_to_string(&mut output).await;
}
output
});

Ok::<_, ProviderError>(tokio::sync::Mutex::new(CliProcess {
child,
stdin,
reader: BufReader::new(stdout),
stderr_handle,
messages_sent: 0,
}))
})
.await?;
let needs_spawn = match guard.as_mut() {
None => true,
Some(p) => !Self::is_process_alive(p),
};
if needs_spawn {
*guard = Some(self.spawn_process(&filtered_system)?);
}

let mut process = process_mutex.lock().await;
let new_messages = self.content_blocks_for(guard.as_ref().unwrap(), messages);
let process = guard.as_mut().unwrap();
match process.send_and_read(&new_messages).await {
Ok(lines) => {
process.messages_sent = messages.len();
tracing::debug!("Command executed successfully, got {} lines", lines.len());
Ok(lines)
}
Err(e) if Self::is_recoverable(&e) => {
tracing::debug!("CLI process dead, respawning");
let process = guard.insert(self.spawn_process(&filtered_system)?);
let new_messages = self.content_blocks_for(process, messages);
let lines = process.send_and_read(&new_messages).await?;
process.messages_sent = messages.len();
tracing::debug!(
"Command executed successfully after respawn, got {} lines",
lines.len()
);
Ok(lines)
}
Err(e) => Err(e),
}
}

// Build content from new messages only (skip already-sent ones).
// If messages is shorter than messages_sent, the caller started a fresh
// conversation on the same provider instance — send everything.
fn content_blocks_for(&self, process: &CliProcess, messages: &[Message]) -> Vec<Value> {
let new_messages = if process.messages_sent > 0 && process.messages_sent < messages.len() {
&messages[process.messages_sent..]
} else {
messages
};
let new_blocks = self.messages_to_content_blocks(new_messages);

// Write NDJSON line to stdin
let ndjson_line = build_stream_json_input(&new_blocks);
process
.stdin
.write_all(ndjson_line.as_bytes())
.await
.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to write to stdin: {}", e))
})?;
process.stdin.write_all(b"\n").await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to write newline to stdin: {}", e))
})?;

// Read lines until we see a "result" event
let mut lines = Vec::new();
let mut line = String::new();

loop {
line.clear();
match process.reader.read_line(&mut line).await {
Ok(0) => {
// EOF means the process died
return Err(ProviderError::RequestFailed(
"Claude CLI process terminated unexpectedly".to_string(),
));
}
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
lines.push(trimmed.to_string());
self.messages_to_content_blocks(new_messages)
}

// Check if this is a result event (end of turn)
if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
match parsed.get("type").and_then(|t| t.as_str()) {
Some("result") => break,
Some("error") => break,
_ => {}
}
}
}
Err(e) => {
return Err(ProviderError::RequestFailed(format!(
"Failed to read output: {}",
e
)));
}
fn is_recoverable(err: &ProviderError) -> bool {
match err {
ProviderError::RequestFailed(msg) => {
msg.contains("Broken pipe") || msg.contains("terminated unexpectedly")
}
_ => false,
}

// Update messages_sent for next turn
process.messages_sent = messages.len();

tracing::debug!("Command executed successfully, got {} lines", lines.len());
for (i, line) in lines.iter().enumerate() {
tracing::debug!("Line {}: {}", i, line);
}

Ok(lines)
}

/// Generate a simple session description without calling subprocess
Expand Down Expand Up @@ -734,12 +767,30 @@ mod tests {
);
}

#[test]
fn test_is_recoverable() {
assert!(ClaudeCodeProvider::is_recoverable(
&ProviderError::RequestFailed(
"Failed to write to stdin: Broken pipe (os error 32)".into()
)
));
assert!(ClaudeCodeProvider::is_recoverable(
&ProviderError::RequestFailed("Claude CLI process terminated unexpectedly".into())
));
assert!(!ClaudeCodeProvider::is_recoverable(
&ProviderError::RequestFailed("Failed to read output: connection reset".into())
));
assert!(!ClaudeCodeProvider::is_recoverable(
&ProviderError::Authentication("Broken pipe".into()) // wrong variant
));
}

fn make_provider() -> ClaudeCodeProvider {
ClaudeCodeProvider {
command: PathBuf::from("claude"),
model: ModelConfig::new("sonnet").unwrap(),
name: "claude-code".to_string(),
cli_process: tokio::sync::OnceCell::new(),
cli_process: tokio::sync::Mutex::new(None),
}
}
}