Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 152 additions & 9 deletions crates/goose/src/providers/gemini_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use async_trait::async_trait;
use serde_json::{json, Value};
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use tokio::process::Command;

use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
use super::base::{
stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata,
ProviderUsage, Usage,
};
use super::cli_common::{error_from_event, extract_usage_tokens};
use super::errors::ProviderError;
use super::utils::{filter_extensions_from_system_prompt, RequestLog};
Expand All @@ -18,6 +21,7 @@ use crate::conversation::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::base::ConfigKey;
use crate::subprocess::configure_subprocess;
use async_stream::try_stream;
use futures::future::BoxFuture;
use rmcp::model::Role;
use rmcp::model::Tool;
Expand All @@ -39,7 +43,7 @@ pub struct GeminiCliProvider {
#[serde(skip)]
name: String,
#[serde(skip)]
cli_session_id: OnceLock<String>,
cli_session_id: Arc<OnceLock<String>>,
}

impl GeminiCliProvider {
Expand All @@ -52,7 +56,7 @@ impl GeminiCliProvider {
command: resolved_command,
model,
name: GEMINI_CLI_PROVIDER_NAME.to_string(),
cli_session_id: OnceLock::new(),
cli_session_id: Arc::new(OnceLock::new()),
})
}

Expand Down Expand Up @@ -119,13 +123,18 @@ impl GeminiCliProvider {
cmd
}

async fn execute_command(
fn spawn_command(
&self,
system: &str,
messages: &[Message],
_tools: &[Tool],
model_name: &str,
) -> Result<Vec<Value>, ProviderError> {
) -> Result<
(
tokio::process::Child,
BufReader<tokio::process::ChildStdout>,
),
ProviderError,
> {
let prompt = self.build_prompt(system, messages);

tracing::debug!(command = ?self.command, "Executing Gemini CLI command");
Expand All @@ -145,6 +154,18 @@ impl GeminiCliProvider {
.take()
.ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?;

Ok((child, BufReader::new(stdout)))
}

async fn execute_command(
&self,
system: &str,
messages: &[Message],
_tools: &[Tool],
model_name: &str,
) -> Result<Vec<Value>, ProviderError> {
let (mut child, mut reader) = self.spawn_command(system, messages, model_name)?;

// Drain stderr concurrently to avoid pipe deadlock
let stderr_task = tokio::spawn(async move {
let mut buf = String::new();
Expand All @@ -154,7 +175,6 @@ impl GeminiCliProvider {
(child, buf)
});

let mut reader = BufReader::new(stdout);
let mut events = Vec::new();
let mut line = String::new();

Expand Down Expand Up @@ -353,6 +373,129 @@ impl Provider for GeminiCliProvider {
ProviderUsage::new(model_config.model_name.clone(), usage),
))
}

fn supports_streaming(&self) -> bool {
true
}

async fn stream(
&self,
_session_id: &str,
system: &str,
messages: &[Message],
_tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
if super::cli_common::is_session_description_request(system) {
let (message, usage) = super::cli_common::generate_simple_session_description(
&self.model.model_name,
messages,
)?;
return Ok(stream_from_single_message(message, usage));
}

let (mut child, mut reader) =
self.spawn_command(system, messages, &self.model.model_name)?;
let session_id_lock = Arc::clone(&self.cli_session_id);
let model_name = self.model.model_name.clone();
let message_id = uuid::Uuid::new_v4().to_string();

// Drain stderr concurrently to avoid pipe deadlock
let stderr = child.stderr.take();
let stderr_drain = tokio::spawn(async move {
let mut buf = String::new();
if let Some(mut stderr) = stderr {
let _ = AsyncReadExt::read_to_string(&mut stderr, &mut buf).await;
}
buf
});

Ok(Box::pin(try_stream! {
let mut line = String::new();
let mut accumulated_usage = Usage::default();
let stream_timestamp = chrono::Utc::now().timestamp();

loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => break,
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}

if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
match parsed.get("type").and_then(|t| t.as_str()) {
Some("init") => {
if let Some(sid) =
parsed.get("session_id").and_then(|s| s.as_str())
{
let _ = session_id_lock.set(sid.to_string());
}
}
Some("message") => {
let is_assistant = parsed.get("role").and_then(|r| r.as_str())
== Some("assistant");
let content = parsed
.get("content")
.and_then(|c| c.as_str())
.unwrap_or("");
if is_assistant && !content.is_empty() {
let mut partial = Message::new(
Role::Assistant,
stream_timestamp,
vec![MessageContent::text(content)],
);
partial.id = Some(message_id.clone());
yield (Some(partial), None);
}
}
Some("result") => {
if let Some(stats) = parsed.get("stats") {
accumulated_usage = extract_usage_tokens(stats);
}
break;
}
Some("error") => {
let _ = child.wait().await;
Err(error_from_event("Gemini CLI", &parsed))?;
}
_ => {}
}
} else {
tracing::warn!(line = trimmed, "Non-JSON line in stream-json output");
}
}
Err(e) => {
let _ = child.wait().await;
Err(ProviderError::RequestFailed(format!(
"Failed to read streaming output: {e}"
)))?;
}
}
}

let stderr_text = stderr_drain.await.unwrap_or_default();
let exit_status = child.wait().await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to wait for command: {e}"))
})?;

if !exit_status.success() {
let stderr_snippet = stderr_text.trim();
let detail = if stderr_snippet.is_empty() {
format!("exit code {:?}", exit_status.code())
} else {
format!("exit code {:?}: {stderr_snippet}", exit_status.code())
};
Err(ProviderError::RequestFailed(format!(
"Gemini CLI command failed ({detail})"
)))?;
}

let provider_usage = ProviderUsage::new(model_name, accumulated_usage);
yield (None, Some(provider_usage));
}))
}
}

#[cfg(test)]
Expand All @@ -365,7 +508,7 @@ mod tests {
command: PathBuf::from("gemini"),
model: ModelConfig::new("gemini-2.5-pro").unwrap(),
name: "gemini-cli".to_string(),
cli_session_id: OnceLock::new(),
cli_session_id: Arc::new(OnceLock::new()),
}
}

Expand Down