diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 2a7aeac8f1ec..976666eb7fd9 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -3,11 +3,14 @@ use async_trait::async_trait; use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader}; use tokio::process::Command; -use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ + stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, + ProviderUsage, Usage, +}; use super::cli_common::{error_from_event, extract_usage_tokens}; use super::errors::ProviderError; use super::utils::{filter_extensions_from_system_prompt, RequestLog}; @@ -18,6 +21,7 @@ use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::ConfigKey; use crate::subprocess::configure_subprocess; +use async_stream::try_stream; use futures::future::BoxFuture; use rmcp::model::Role; use rmcp::model::Tool; @@ -39,7 +43,7 @@ pub struct GeminiCliProvider { #[serde(skip)] name: String, #[serde(skip)] - cli_session_id: OnceLock, + cli_session_id: Arc>, } impl GeminiCliProvider { @@ -52,7 +56,7 @@ impl GeminiCliProvider { command: resolved_command, model, name: GEMINI_CLI_PROVIDER_NAME.to_string(), - cli_session_id: OnceLock::new(), + cli_session_id: Arc::new(OnceLock::new()), }) } @@ -119,13 +123,18 @@ impl GeminiCliProvider { cmd } - async fn execute_command( + fn spawn_command( &self, system: &str, messages: &[Message], - _tools: &[Tool], model_name: &str, - ) -> Result, ProviderError> { + ) -> Result< + ( + tokio::process::Child, + BufReader, + ), + ProviderError, + > { let prompt = self.build_prompt(system, messages); tracing::debug!(command = ?self.command, "Executing Gemini CLI command"); @@ -145,6 +154,18 @@ impl GeminiCliProvider { .take() .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; + Ok((child, BufReader::new(stdout))) + } + + async fn execute_command( + &self, + system: &str, + messages: &[Message], + _tools: &[Tool], + model_name: &str, + ) -> Result, ProviderError> { + let (mut child, mut reader) = self.spawn_command(system, messages, model_name)?; + // Drain stderr concurrently to avoid pipe deadlock let stderr_task = tokio::spawn(async move { let mut buf = String::new(); @@ -154,7 +175,6 @@ impl GeminiCliProvider { (child, buf) }); - let mut reader = BufReader::new(stdout); let mut events = Vec::new(); let mut line = String::new(); @@ -353,6 +373,129 @@ impl Provider for GeminiCliProvider { ProviderUsage::new(model_config.model_name.clone(), usage), )) } + + fn supports_streaming(&self) -> bool { + true + } + + async fn stream( + &self, + _session_id: &str, + system: &str, + messages: &[Message], + _tools: &[Tool], + ) -> Result { + if super::cli_common::is_session_description_request(system) { + let (message, usage) = super::cli_common::generate_simple_session_description( + &self.model.model_name, + messages, + )?; + return Ok(stream_from_single_message(message, usage)); + } + + let (mut child, mut reader) = + self.spawn_command(system, messages, &self.model.model_name)?; + let session_id_lock = Arc::clone(&self.cli_session_id); + let model_name = self.model.model_name.clone(); + let message_id = uuid::Uuid::new_v4().to_string(); + + // Drain stderr concurrently to avoid pipe deadlock + let stderr = child.stderr.take(); + let stderr_drain = tokio::spawn(async move { + let mut buf = String::new(); + if let Some(mut stderr) = stderr { + let _ = AsyncReadExt::read_to_string(&mut stderr, &mut buf).await; + } + buf + }); + + Ok(Box::pin(try_stream! { + let mut line = String::new(); + let mut accumulated_usage = Usage::default(); + let stream_timestamp = chrono::Utc::now().timestamp(); + + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + + if let Ok(parsed) = serde_json::from_str::(trimmed) { + match parsed.get("type").and_then(|t| t.as_str()) { + Some("init") => { + if let Some(sid) = + parsed.get("session_id").and_then(|s| s.as_str()) + { + let _ = session_id_lock.set(sid.to_string()); + } + } + Some("message") => { + let is_assistant = parsed.get("role").and_then(|r| r.as_str()) + == Some("assistant"); + let content = parsed + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + if is_assistant && !content.is_empty() { + let mut partial = Message::new( + Role::Assistant, + stream_timestamp, + vec![MessageContent::text(content)], + ); + partial.id = Some(message_id.clone()); + yield (Some(partial), None); + } + } + Some("result") => { + if let Some(stats) = parsed.get("stats") { + accumulated_usage = extract_usage_tokens(stats); + } + break; + } + Some("error") => { + let _ = child.wait().await; + Err(error_from_event("Gemini CLI", &parsed))?; + } + _ => {} + } + } else { + tracing::warn!(line = trimmed, "Non-JSON line in stream-json output"); + } + } + Err(e) => { + let _ = child.wait().await; + Err(ProviderError::RequestFailed(format!( + "Failed to read streaming output: {e}" + )))?; + } + } + } + + let stderr_text = stderr_drain.await.unwrap_or_default(); + let exit_status = child.wait().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to wait for command: {e}")) + })?; + + if !exit_status.success() { + let stderr_snippet = stderr_text.trim(); + let detail = if stderr_snippet.is_empty() { + format!("exit code {:?}", exit_status.code()) + } else { + format!("exit code {:?}: {stderr_snippet}", exit_status.code()) + }; + Err(ProviderError::RequestFailed(format!( + "Gemini CLI command failed ({detail})" + )))?; + } + + let provider_usage = ProviderUsage::new(model_name, accumulated_usage); + yield (None, Some(provider_usage)); + })) + } } #[cfg(test)] @@ -365,7 +508,7 @@ mod tests { command: PathBuf::from("gemini"), model: ModelConfig::new("gemini-2.5-pro").unwrap(), name: "gemini-cli".to_string(), - cli_session_id: OnceLock::new(), + cli_session_id: Arc::new(OnceLock::new()), } }