diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index fcc22dc4b0b4..67081b6369a4 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -981,6 +981,10 @@ impl CliSession { if permission == Permission::Cancel { output::render_text("Tool call cancelled. Returning to chat...", Some(Color::Yellow), true); + self.agent.handle_confirmation(id.clone(), PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: Permission::DenyOnce, + }).await; let mut response_message = Message::user(); response_message.content.push(MessageContent::tool_response( id, diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 858adb811b95..e4134dd79d4a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -39,7 +39,7 @@ use crate::mcp_utils::ToolResult; use crate::permission::permission_inspector::PermissionInspector; use crate::permission::permission_judge::PermissionCheckResult; use crate::permission::PermissionConfirmation; -use crate::providers::base::Provider; +use crate::providers::base::{PermissionRouting, Provider}; use crate::providers::errors::ProviderError; use crate::recipe::{Author, Recipe, Response, Settings}; use crate::scheduler_trait::SchedulerTrait; @@ -846,11 +846,28 @@ impl Agent { request_id: String, confirmation: PermissionConfirmation, ) { + let provider = self.provider.lock().await.clone(); + if let Some(provider) = provider.as_ref() { + if provider.permission_routing() == PermissionRouting::ActionRequired + && provider + .handle_permission_confirmation(&request_id, &confirmation) + .await + { + return; + } + } if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await { error!("Failed to send confirmation: {}", e); } } + pub async fn supports_action_required_permissions(&self) -> bool { + if let Some(provider) = self.provider.lock().await.as_ref() { + return provider.permission_routing() == PermissionRouting::ActionRequired; + } + false + } + #[instrument( skip(self, user_message, session_config), fields(user_message, trace_input) @@ -2014,8 +2031,119 @@ impl Agent { #[cfg(test)] mod tests { use super::*; + use crate::permission::permission_confirmation::PrincipalType; + use crate::providers::base::PermissionRouting; use crate::recipe::Response; + struct ActionRequiredProvider { + handled: tokio::sync::Mutex>, + } + + impl ActionRequiredProvider { + fn new() -> Self { + Self { + handled: tokio::sync::Mutex::new(Vec::new()), + } + } + } + + impl std::fmt::Debug for ActionRequiredProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ActionRequiredProvider").finish() + } + } + + #[async_trait::async_trait] + impl crate::providers::base::Provider for ActionRequiredProvider { + fn get_name(&self) -> &str { + "test-action-required" + } + fn get_model_config(&self) -> crate::model::ModelConfig { + crate::model::ModelConfig::new("test").unwrap() + } + async fn stream( + &self, + _: &crate::model::ModelConfig, + _: &str, + _: &str, + _: &[crate::conversation::message::Message], + _: &[rmcp::model::Tool], + ) -> Result + { + unimplemented!() + } + fn permission_routing(&self) -> PermissionRouting { + PermissionRouting::ActionRequired + } + async fn handle_permission_confirmation( + &self, + request_id: &str, + confirmation: &PermissionConfirmation, + ) -> bool { + self.handled + .lock() + .await + .push((request_id.to_string(), confirmation.clone())); + request_id == "known" + } + } + + #[tokio::test] + async fn test_handle_confirmation_routes_to_provider() { + let agent = Agent::new(); + let provider = Arc::new(ActionRequiredProvider::new()); + *agent.provider.lock().await = + Some(provider.clone() as Arc); + + // Known request_id → provider handles it, confirmation_tx NOT called + agent + .handle_confirmation( + "known".to_string(), + PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: crate::permission::Permission::AllowOnce, + }, + ) + .await; + assert_eq!(provider.handled.lock().await.len(), 1); + + // Unknown request_id → provider returns false, falls through to confirmation_tx + agent + .handle_confirmation( + "unknown".to_string(), + PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: crate::permission::Permission::DenyOnce, + }, + ) + .await; + assert_eq!(provider.handled.lock().await.len(), 2); + // Verify the fallthrough went to confirmation_rx + let mut rx = agent.confirmation_rx.lock().await; + let (id, conf) = rx.recv().await.unwrap(); + assert_eq!(id, "unknown"); + assert_eq!(conf.permission, crate::permission::Permission::DenyOnce); + } + + #[tokio::test] + async fn test_handle_confirmation_noop_provider() { + let agent = Agent::new(); + // No provider set → Noop routing, goes straight to confirmation_tx + agent + .handle_confirmation( + "any".to_string(), + PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: crate::permission::Permission::AllowOnce, + }, + ) + .await; + + let mut rx = agent.confirmation_rx.lock().await; + let (id, _) = rx.recv().await.unwrap(); + assert_eq!(id, "any"); + } + #[tokio::test] async fn test_add_final_output_tool() -> Result<()> { let agent = Agent::new(); diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index ea498a61f0b5..e82aaabeb466 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -12,6 +12,7 @@ use crate::config::ExtensionConfig; use crate::conversation::message::{Message, MessageContent}; use crate::conversation::Conversation; use crate::model::ModelConfig; +use crate::permission::PermissionConfirmation; use crate::utils::safe_truncate; use rmcp::model::Tool; use utoipa::ToSchema; @@ -432,6 +433,12 @@ pub trait ProviderDef: Send + Sync { Self: Sized; } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum PermissionRouting { + ActionRequired, + Noop, +} + /// Trait for LeadWorkerProvider-specific functionality pub trait LeadWorkerProviderTrait { /// Get information about the lead and worker models for logging @@ -691,6 +698,18 @@ pub trait Provider: Send + Sync { "OAuth configuration not supported by this provider".to_string(), )) } + + fn permission_routing(&self) -> PermissionRouting { + PermissionRouting::Noop + } + + async fn handle_permission_confirmation( + &self, + _request_id: &str, + _confirmation: &PermissionConfirmation, + ) -> bool { + false + } } /// A message stream yields partial text content but complete tool calls, all within the Message object diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index a46ff5534e99..c36f866cf8df 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -3,18 +3,21 @@ use async_stream::try_stream; use async_trait::async_trait; use futures::future::BoxFuture; use rmcp::model::{Role, Tool}; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use std::collections::HashMap; use std::io::Write; use std::path::{Path, PathBuf}; use std::process::Stdio; use std::sync::Arc; use tempfile::NamedTempFile; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::process::Command; +use tokio::sync::oneshot; use super::base::{ - stream_from_single_message, ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, + stream_from_single_message, ConfigKey, MessageStream, PermissionRouting, Provider, ProviderDef, + ProviderMetadata, ProviderUsage, Usage, }; use super::errors::ProviderError; use super::utils::filter_extensions_from_system_prompt; @@ -24,6 +27,8 @@ use crate::config::search_path::SearchPaths; use crate::config::{Config, ExtensionConfig, GooseMode}; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; +use crate::permission::permission_confirmation::PrincipalType; +use crate::permission::{Permission, PermissionConfirmation}; use crate::subprocess::configure_subprocess; use super::cli_common::{error_from_event, extract_usage_tokens}; @@ -32,6 +37,116 @@ const CLAUDE_CODE_PROVIDER_NAME: &str = "claude-code"; pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "default"; pub const CLAUDE_CODE_DOC_URL: &str = "https://code.claude.com/docs/en/setup"; +// https://github.com/anthropics/claude-agent-sdk-python/blob/0e9397e/src/claude_agent_sdk/types.py#L857-L859 +#[derive(Serialize)] +struct ControlResponse { + #[serde(rename = "type")] + msg_type: &'static str, + response: ControlResponseBody, +} + +#[derive(Serialize)] +struct ControlResponseBody { + subtype: &'static str, + request_id: String, + response: T, +} + +// https://github.com/anthropics/claude-agent-sdk-python/blob/0e9397e/src/claude_agent_sdk/types.py#L135-L153 +#[derive(Serialize)] +#[serde(tag = "behavior")] +enum PermissionResponse { + #[serde(rename = "allow")] + Allow { + #[serde(rename = "updatedInput")] + updated_input: serde_json::Map, + #[serde(rename = "toolUseID")] + tool_use_id: String, + }, + #[serde(rename = "deny")] + Deny { message: String }, +} + +#[derive(Serialize)] +struct ControlRequest { + #[serde(rename = "type")] + msg_type: &'static str, + request_id: String, + request: ControlRequestBody, +} + +#[derive(Serialize)] +#[serde(tag = "subtype")] +enum ControlRequestBody { + #[serde(rename = "initialize")] + Initialize, + #[serde(rename = "set_model")] + SetModel { model: String }, +} + +impl ControlRequestBody { + fn label(&self) -> &'static str { + match self { + Self::Initialize => "initialize", + Self::SetModel { .. } => "set_model", + } + } +} + +#[derive(Deserialize)] +struct IncomingControlResponse { + response: IncomingControlResponseBody, +} + +#[derive(Deserialize)] +#[serde(tag = "subtype")] +enum IncomingControlResponseBody { + #[serde(rename = "success")] + Success { + request_id: String, + #[serde(default)] + response: Option, + }, + #[serde(rename = "error")] + Error { + request_id: String, + #[serde(default)] + error: String, + }, +} + +#[derive(Deserialize)] +struct IncomingControlRequest { + request_id: String, + request: IncomingRequestBody, +} + +#[derive(Deserialize)] +#[serde(tag = "subtype")] +enum IncomingRequestBody { + #[serde(rename = "can_use_tool")] + CanUseTool { + tool_name: String, + #[serde(default)] + input: serde_json::Map, + #[serde(default)] + tool_use_id: String, + }, +} + +impl ControlResponse { + fn success(request_id: String, response: T) -> Self { + Self { + msg_type: "control_response", + response: ControlResponseBody { + subtype: "success", + request_id, + response, + }, + } + } +} + struct CliProcess { child: tokio::process::Child, stdin: Box, @@ -60,81 +175,25 @@ impl CliProcess { format!("req_{id}") } - /// Send a `set_model` control request and wait for the response before returning. - /// Skips the request if the model is already active. + async fn send_control_request( + &mut self, + body: ControlRequestBody, + ) -> Result, ProviderError> { + let request_id = self.next_request_id(); + exchange_control(&mut self.stdin, &mut self.reader, &request_id, body).await + } + async fn send_set_model(&mut self, model: &str) -> Result<(), ProviderError> { if model == self.current_model { return Ok(()); } - - let request_id = self.next_request_id(); - let req = json!({ - "type": "control_request", - "request_id": request_id, - "request": {"subtype": "set_model", "model": model} - }); - let mut req_str = serde_json::to_string(&req).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to serialize set_model request: {e}")) - })?; - req_str.push('\n'); - self.stdin - .write_all(req_str.as_bytes()) - .await - .map_err(|e| { - ProviderError::RequestFailed(format!("Failed to write set_model request: {e}")) - })?; - - // Read lines until we get the control_response for our request. - let mut line = String::new(); - loop { - line.clear(); - match self.reader.read_line(&mut line).await { - Ok(0) => { - return Err(ProviderError::RequestFailed( - "CLI process terminated while waiting for set_model response".to_string(), - )); - } - Ok(_) => { - let trimmed = line.trim(); - if trimmed.is_empty() { - continue; - } - if let Ok(parsed) = serde_json::from_str::(trimmed) { - if parsed.get("type").and_then(|t| t.as_str()) == Some("control_response") { - // Skip responses that don't match our request_id - if parsed - .pointer("/response/request_id") - .and_then(|id| id.as_str()) - != Some(request_id.as_str()) - { - continue; - } - let success = - parsed.pointer("/response/subtype").and_then(|s| s.as_str()) - == Some("success"); - if success { - self.current_model = model.to_string(); - self.log_model_update = true; - return Ok(()); - } else { - let err = parsed - .pointer("/response/error") - .and_then(|e| e.as_str()) - .unwrap_or("unknown"); - return Err(ProviderError::RequestFailed(format!( - "set_model failed: {err}" - ))); - } - } - } - } - Err(e) => { - return Err(ProviderError::RequestFailed(format!( - "Failed to read set_model response: {e}" - ))); - } - } - } + self.send_control_request(ControlRequestBody::SetModel { + model: model.to_string(), + }) + .await?; + self.current_model = model.to_string(); + self.log_model_update = true; + Ok(()) } async fn drain_pending_response(&mut self) { @@ -208,6 +267,9 @@ pub struct ClaudeCodeProvider { mcp_config_file: Option, #[serde(skip)] cli_process: tokio::sync::OnceCell>>, + #[serde(skip)] + pending_confirmations: + Arc>>>, } impl ClaudeCodeProvider { @@ -282,32 +344,29 @@ impl ClaudeCodeProvider { cmd } - fn apply_permission_flags(cmd: &mut Command) -> Result<(), ProviderError> { + /// Returns true when the control protocol is enabled (Approve mode). + fn apply_permission_flags(cmd: &mut Command) -> Result { let config = Config::global(); let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto); match goose_mode { GooseMode::Auto => { cmd.arg("--dangerously-skip-permissions"); + Ok(false) } GooseMode::SmartApprove => { cmd.arg("--permission-mode").arg("acceptEdits"); + Ok(false) } GooseMode::Approve => { - return Err(ProviderError::RequestFailed( - "\n\n\n### NOTE\n\n\n \ - Claude Code CLI provider does not support Approve mode.\n \ - Please use Auto (which will run anything it needs to) or \ - SmartApprove (most things will run or Chat Mode)\n\n\n" - .to_string(), - )); + cmd.arg("--permission-prompt-tool").arg("stdio"); + Ok(true) } - GooseMode::Chat => {} + GooseMode::Chat => Ok(false), } - Ok(()) } - fn spawn_process(&self, filtered_system: &str) -> Result { + async fn spawn_process(&self, filtered_system: &str) -> Result { let mut cmd = self.build_stream_json_command(); if let Some(f) = &self.mcp_config_file { @@ -321,7 +380,7 @@ impl ClaudeCodeProvider { .arg("--model") .arg(&self.model.model_name); - Self::apply_permission_flags(&mut cmd)?; + let control_protocol_enabled = Self::apply_permission_flags(&mut cmd)?; let mut child = cmd.spawn().map_err(|e| { ProviderError::RequestFailed(format!( @@ -349,7 +408,7 @@ impl ClaudeCodeProvider { output }); - Ok(CliProcess { + let mut process = CliProcess { child, stdin: Box::new(stdin), reader: BufReader::new(Box::new(stdout)), @@ -358,7 +417,15 @@ impl ClaudeCodeProvider { log_model_update: false, next_request_id: 0, needs_drain: false, - }) + }; + + if control_protocol_enabled { + process + .send_control_request(ControlRequestBody::Initialize) + .await?; + } + + Ok(process) } async fn get_or_init_process( @@ -368,37 +435,83 @@ impl ClaudeCodeProvider { self.cli_process .get_or_try_init(|| async { Ok(Arc::new(tokio::sync::Mutex::new( - self.spawn_process(filtered_system)?, + self.spawn_process(filtered_system).await?, ))) }) .await } } -/// Extract model aliases from the CLI's initialize control_response. -fn parse_models_from_lines(lines: &[String]) -> Vec { - for line in lines { - if let Ok(parsed) = serde_json::from_str::(line) { - if parsed.get("type").and_then(|t| t.as_str()) != Some("control_response") { - continue; +async fn exchange_control( + stdin: &mut (impl AsyncWrite + Unpin), + reader: &mut (impl AsyncBufRead + Unpin), + request_id: &str, + body: ControlRequestBody, +) -> Result, ProviderError> { + let label = body.label(); + let req = ControlRequest { + msg_type: "control_request", + request_id: request_id.to_string(), + request: body, + }; + let mut req_str = serde_json::to_string(&req).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to serialize {label} request: {e}")) + })?; + req_str.push('\n'); + stdin.write_all(req_str.as_bytes()).await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write {label} request: {e}")) + })?; + + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => { + return Err(ProviderError::RequestFailed(format!( + "CLI process terminated while waiting for {label} response" + ))); } - let success = - parsed.pointer("/response/subtype").and_then(|s| s.as_str()) == Some("success"); - if !success { - continue; + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + if let Ok(msg) = serde_json::from_str::(trimmed) { + match msg.response { + IncomingControlResponseBody::Success { + request_id: ref rid, + response, + } if rid == request_id => return Ok(response), + IncomingControlResponseBody::Error { + request_id: ref rid, + error, + } if rid == request_id => { + return Err(ProviderError::RequestFailed(format!( + "{label} failed: {error}" + ))); + } + _ => continue, + } + } } - if let Some(models) = parsed - .pointer("/response/response/models") - .and_then(|m| m.as_array()) - { - return models - .iter() - .filter_map(|m| m.get("value").and_then(|v| v.as_str()).map(String::from)) - .collect(); + Err(e) => { + return Err(ProviderError::RequestFailed(format!( + "Failed to read {label} response: {e}" + ))); } } } - Vec::new() +} + +fn extract_model_aliases(response: Option<&Value>) -> Vec { + response + .and_then(|v| v.get("models")?.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|m| m.get("value")?.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default() } fn build_stream_json_input(content_blocks: &[Value], session_id: &str) -> String { @@ -513,6 +626,7 @@ impl ProviderDef for ClaudeCodeProvider { name: CLAUDE_CODE_PROVIDER_NAME.to_string(), mcp_config_file, cli_process: tokio::sync::OnceCell::new(), + pending_confirmations: Arc::new(tokio::sync::Mutex::new(HashMap::new())), }) }) } @@ -547,46 +661,33 @@ impl Provider for ClaudeCodeProvider { .take() .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; - let request = json!({ - "type": "control_request", - "request_id": "model_list", - "request": {"subtype": "initialize"} - }); - let mut request_str = serde_json::to_string(&request).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to serialize initialize request: {e}")) - })?; - request_str.push('\n'); - stdin.write_all(request_str.as_bytes()).await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to write initialize request: {e}")) - })?; - let mut reader = BufReader::new(stdout); - let mut lines = Vec::new(); - let mut line = String::new(); - - // Read until we see a control_response or hit EOF - loop { - line.clear(); - match reader.read_line(&mut line).await { - Ok(0) => break, - Ok(_) => { - let trimmed = line.trim(); - if trimmed.is_empty() { - continue; - } - lines.push(trimmed.to_string()); - if let Ok(parsed) = serde_json::from_str::(trimmed) { - if parsed.get("type").and_then(|t| t.as_str()) == Some("control_response") { - break; - } - } - } - Err(_) => break, - } - } - + let response = exchange_control( + &mut stdin, + &mut reader, + "model_list", + ControlRequestBody::Initialize, + ) + .await; let _ = child.kill().await; - Ok(parse_models_from_lines(&lines)) + Ok(extract_model_aliases(response.ok().flatten().as_ref())) + } + + fn permission_routing(&self) -> PermissionRouting { + PermissionRouting::ActionRequired + } + + async fn handle_permission_confirmation( + &self, + request_id: &str, + confirmation: &PermissionConfirmation, + ) -> bool { + let mut pending = self.pending_confirmations.lock().await; + if let Some(tx) = pending.remove(request_id) { + let _ = tx.send(confirmation.clone()); + return true; + } + false } async fn stream( @@ -613,11 +714,30 @@ impl Provider for ClaudeCodeProvider { let ndjson_line = build_stream_json_input(&blocks, session_id); let model_name = model_config.model_name.clone(); let message_id = uuid::Uuid::new_v4().to_string(); + let pending_confirmations = Arc::clone(&self.pending_confirmations); Ok(Box::pin(try_stream! { // Single lock acquisition covers write-to-stdin and read-from-stdout, // eliminating the race window between the two. let mut process = process_arc.lock_owned().await; + + // Clean up pending permissions from a cancelled stream + { + let mut pending = pending_confirmations.lock().await; + for (req_id, tx) in pending.drain() { + drop(tx); + let resp = ControlResponse::success( + req_id, + PermissionResponse::Deny { message: "Stream cancelled".to_string() }, + ); + let mut s = serde_json::to_string(&resp).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to serialize cleanup deny response: {e}")) + })?; + s.push('\n'); + let _ = process.stdin.write_all(s.as_bytes()).await; + } + } + process.drain_pending_response().await; process.send_set_model(&model_name).await?; @@ -719,6 +839,49 @@ impl Provider for ClaudeCodeProvider { stream_error = Some(error_from_event("Claude CLI", &parsed)); break; } + Some("control_request") => { + if let Ok(IncomingControlRequest { + request_id, + request: IncomingRequestBody::CanUseTool { tool_name, input, tool_use_id }, + }) = serde_json::from_str::(trimmed) { + tracing::debug!(raw = %parsed, "can_use_tool control_request received"); + + let (tx, rx) = oneshot::channel(); + pending_confirmations.lock().await.insert(request_id.clone(), tx); + + let action_msg = Message::assistant().with_action_required( + request_id.clone(), tool_name, input.clone(), None, + ); + yield (Some(action_msg), None); + + let confirmation = rx.await.unwrap_or(PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: Permission::Cancel, + }); + pending_confirmations.lock().await.remove(&request_id); + + let perm_resp = match confirmation.permission { + Permission::AlwaysAllow | Permission::AllowOnce => { + PermissionResponse::Allow { + updated_input: input, + tool_use_id, + } + } + _ => PermissionResponse::Deny { + message: "User denied the tool call".to_string(), + }, + }; + let resp = ControlResponse::success(request_id, perm_resp); + let mut resp_str = serde_json::to_string(&resp).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to serialize permission response: {e}")) + })?; + tracing::debug!(json = %resp_str, "can_use_tool control_response sent"); + resp_str.push('\n'); + process.stdin.write_all(resp_str.as_bytes()).await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write permission response: {e}")) + })?; + } + } Some("system") if process.log_model_update => { if let Some(resolved) = parsed.get("model").and_then(|m| m.as_str()) { tracing::debug!( @@ -902,34 +1065,27 @@ mod tests { } #[test_case( - &[ - r#"{"type":"control_response","response":{"subtype":"success","request_id":"model_list","response":{"models":[{"value":"default","displayName":"Default (recommended)","description":"Opus 4.6 · Most capable for complex work"},{"value":"sonnet","displayName":"Sonnet","description":"Sonnet 4.5 · Best for everyday tasks"},{"value":"haiku","displayName":"Haiku","description":"Haiku 4.5 · Fastest for quick answers"}]}}}"#, - ], - &["default", "sonnet", "haiku"] + Some(json!({"models":[{"value":"default","displayName":"Default"},{"value":"sonnet","displayName":"Sonnet"},{"value":"haiku","displayName":"Haiku"}]})), + vec!["default".into(), "sonnet".into(), "haiku".into()] ; "success" )] #[test_case( - &[ - r#"{"type":"control_response","response":{"subtype":"success","request_id":"model_list","response":{"models":[{"value":"default","displayName":"Default","description":"..."},{"value":null,"displayName":"Bad","description":"..."}]}}}"#, - ], - &["default"] + Some(json!({"models":[{"value":"default","displayName":"Default"},{"value":null,"displayName":"Bad"}]})), + vec!["default".into()] ; "filters_null_values" )] #[test_case( - &[r#"{"type":"system","subtype":"init","session_id":"abc"}"#], - &[] - ; "no_control_response" + None, + vec![] + ; "none_input" )] #[test_case( - &[r#"{"type":"control_response","response":{"subtype":"error","request_id":"req_1","error":"fail"}}"#], - &[] - ; "error_response" + Some(json!({"other":"data"})), + vec![] + ; "no_models_key" )] - fn test_parse_models_from_lines(lines: &[&str], expected: &[&str]) { - let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); - let result = parse_models_from_lines(&lines); - let expected: Vec = expected.iter().map(|s| s.to_string()).collect(); - assert_eq!(result, expected); + fn test_extract_model_aliases(response: Option, expected: Vec) { + assert_eq!(extract_model_aliases(response.as_ref()), expected); } #[test_case( @@ -1044,6 +1200,7 @@ mod tests { name: "claude-code".to_string(), mcp_config_file: None, cli_process: tokio::sync::OnceCell::new(), + pending_confirmations: Arc::new(tokio::sync::Mutex::new(HashMap::new())), } } @@ -1067,6 +1224,43 @@ mod tests { (process, stdin_reader) } + async fn stream_with_canned_stdout( + canned_lines: &[&str], + ) -> (ClaudeCodeProvider, MessageStream, tokio::io::DuplexStream) { + let canned_stdout = canned_lines.join("\n"); + let (process, stdin_reader) = make_test_process(&canned_stdout); + let provider = make_provider(); + let process_arc = Arc::new(tokio::sync::Mutex::new(process)); + provider.cli_process.set(process_arc).unwrap(); + + let messages = vec![Message::user().with_text("test")]; + let stream = provider + .stream(&provider.model, "test-session", "", &messages, &[]) + .await + .unwrap(); + (provider, stream, stdin_reader) + } + + async fn capture_stdin( + provider: &ClaudeCodeProvider, + mut reader: tokio::io::DuplexStream, + ) -> String { + use tokio::io::AsyncReadExt; + provider.cli_process.get().unwrap().lock().await.stdin = Box::new(tokio::io::sink()); + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + String::from_utf8(buf).unwrap() + } + + fn extract_permission_response(stdin_str: &str, request_id: &str) -> Value { + let line = stdin_str + .lines() + .find(|l| l.contains(request_id) && l.contains("control_response")) + .unwrap(); + let json: Value = serde_json::from_str(line).unwrap(); + json.pointer("/response/response").unwrap().clone() + } + #[test_case( &[r#"{"type":"control_response","response":{"subtype":"success","request_id":"req_0"}}"#], Some("default"), "sonnet", @@ -1128,4 +1322,140 @@ mod tests { } assert_eq!(String::from_utf8(stdin_bytes).unwrap(), expected_stdin); } + + #[test_case( + Permission::AllowOnce, + json!({"behavior":"allow","updatedInput":{"path":"foo.txt","content":"hello"},"toolUseID":"tu_1"}) + ; "allow" + )] + #[test_case( + Permission::DenyOnce, + json!({"behavior":"deny","message":"User denied the tool call"}) + ; "deny" + )] + #[tokio::test] + async fn test_can_use_tool(permission: Permission, expected_response: Value) { + use futures::StreamExt; + + let (provider, mut stream, stdin_reader) = stream_with_canned_stdout(&[ + r#"{"type":"control_response","response":{"subtype":"success","request_id":"req_0"}}"#, + r#"{"type":"control_request","request_id":"perm_1","request":{"subtype":"can_use_tool","tool_name":"Write","input":{"path":"foo.txt","content":"hello"},"tool_use_id":"tu_1"}}"#, + r#"{"type":"result","result":"Done","usage":{"input_tokens":10,"output_tokens":5}}"#, + ]).await; + + let (first_msg, _) = stream.next().await.unwrap().unwrap(); + let first_msg = first_msg.unwrap(); + let ar = first_msg + .content + .iter() + .find_map(|c| c.as_action_required()) + .unwrap(); + match &ar.data { + crate::conversation::message::ActionRequiredData::ToolConfirmation { + id, + tool_name, + .. + } => { + assert_eq!(id, "perm_1"); + assert_eq!(tool_name, "Write"); + } + _ => panic!("expected ToolConfirmation"), + } + + let handled = provider + .handle_permission_confirmation( + "perm_1", + &PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: permission.clone(), + }, + ) + .await; + assert!(handled); + assert!(provider.pending_confirmations.lock().await.is_empty()); + + while let Some(item) = stream.next().await { + item.unwrap(); + } + drop(stream); + + let stdin_str = capture_stdin(&provider, stdin_reader).await; + let response_data = extract_permission_response(&stdin_str, "perm_1"); + assert_eq!(response_data, expected_response); + } + + #[tokio::test] + async fn test_can_use_tool_cancel_on_drop() { + use futures::StreamExt; + + let (provider, mut stream, stdin_reader) = stream_with_canned_stdout(&[ + r#"{"type":"control_response","response":{"subtype":"success","request_id":"req_0"}}"#, + r#"{"type":"control_request","request_id":"perm_1","request":{"subtype":"can_use_tool","tool_name":"Write","input":{"path":"foo.txt"},"tool_use_id":"tu_1"}}"#, + r#"{"type":"result","result":"Done","usage":{"input_tokens":10,"output_tokens":5}}"#, + ]).await; + + let pending = Arc::clone(&provider.pending_confirmations); + + let (first_msg, _) = stream.next().await.unwrap().unwrap(); + assert!(first_msg + .unwrap() + .content + .iter() + .any(|c| c.as_action_required().is_some())); + + let tx = pending.lock().await.remove("perm_1").unwrap(); + drop(tx); + + while let Some(item) = stream.next().await { + item.unwrap(); + } + drop(stream); + + let stdin_str = capture_stdin(&provider, stdin_reader).await; + let response_data = extract_permission_response(&stdin_str, "perm_1"); + assert_eq!( + response_data, + json!({"behavior":"deny","message":"User denied the tool call"}) + ); + } + + #[tokio::test] + async fn test_pending_permissions_cleaned_on_new_stream() { + use futures::StreamExt; + + let canned_stdout = [ + r#"{"type":"control_response","response":{"subtype":"success","request_id":"req_0"}}"#, + r#"{"type":"result","result":"Done","usage":{"input_tokens":10,"output_tokens":5}}"#, + ] + .join("\n"); + + let (process, stdin_reader) = make_test_process(&canned_stdout); + let provider = make_provider(); + let process_arc = Arc::new(tokio::sync::Mutex::new(process)); + provider.cli_process.set(process_arc).unwrap(); + + let (tx, _rx) = oneshot::channel(); + provider + .pending_confirmations + .lock() + .await + .insert("stale_1".to_string(), tx); + + let messages = vec![Message::user().with_text("test")]; + let mut stream = provider + .stream(&provider.model, "test-session", "", &messages, &[]) + .await + .unwrap(); + + while let Some(item) = stream.next().await { + item.unwrap(); + } + drop(stream); + + assert!(provider.pending_confirmations.lock().await.is_empty()); + + let stdin_str = capture_stdin(&provider, stdin_reader).await; + let response_data = extract_permission_response(&stdin_str, "stale_1"); + assert_eq!(response_data["behavior"], "deny"); + } } diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index c048b5bdab09..2cd9f94110e1 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -1,9 +1,14 @@ use anyhow::Result; use dotenvy::dotenv; +use futures::StreamExt; use goose::agents::extension_manager::ExtensionManagerCapabilities; -use goose::agents::{ExtensionManager, GoosePlatform, PromptManager}; -use goose::config::ExtensionConfig; -use goose::conversation::message::{Message, MessageContent}; +use goose::agents::{ + Agent, AgentConfig, AgentEvent, ExtensionManager, GoosePlatform, PromptManager, SessionConfig, +}; +use goose::config::{ExtensionConfig, GooseMode, PermissionManager}; +use goose::conversation::message::{ActionRequiredData, Message, MessageContent}; +use goose::permission::permission_confirmation::PrincipalType; +use goose::permission::{Permission, PermissionConfirmation}; use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL; use goose::providers::azure::AZURE_DEFAULT_MODEL; use goose::providers::base::Provider; @@ -19,7 +24,7 @@ use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; use goose::providers::sagemaker_tgi::SAGEMAKER_TGI_DEFAULT_MODEL; use goose::providers::snowflake::SNOWFLAKE_DEFAULT_MODEL; use goose::providers::xai::XAI_DEFAULT_MODEL; -use goose::session::SessionManager; +use goose::session::{SessionManager, SessionType}; use goose_test_support::{ExpectedSessionId, McpFixture, FAKE_CODE, TEST_SESSION_ID}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -94,6 +99,7 @@ struct ProviderTester { extension_manager: Arc, is_cli_provider: bool, model_switch_name: Option, + mcp_extension: ExtensionConfig, } impl ProviderTester { @@ -103,6 +109,7 @@ impl ProviderTester { extension_manager: Arc, is_cli_provider: bool, model_switch_name: Option, + mcp_extension: ExtensionConfig, ) -> Self { Self { provider, @@ -110,6 +117,7 @@ impl ProviderTester { extension_manager, is_cli_provider, model_switch_name, + mcp_extension, } } @@ -376,6 +384,7 @@ impl ProviderTester { } async fn run_test_suite(&self) -> Result<()> { + let _guard = env_lock::lock_env([("GOOSE_MODE", Some("auto"))]); self.test_model_listing().await?; self.test_basic_response(&self.session_id_for_test("basic_response")) .await?; @@ -393,8 +402,104 @@ impl ProviderTester { self.test_context_length_exceeded_error(&self.session_id_for_test("context_length")) .await?; } + drop(_guard); + // codex: one-shot subprocess, no bidirectional control protocol + if self.name != "codex" { + self.test_permission_allow().await?; + self.test_permission_deny().await?; + } Ok(()) } + + async fn run_permission_test(&self, permission: Permission, label: &str) -> Result<()> { + // Guard must live through agent.reply() — providers read GOOSE_MODE at spawn time. + let _guard = env_lock::lock_env([("GOOSE_MODE", Some("approve"))]); + let provider = if self.is_cli_provider { + create_with_named_model( + &self.name.to_lowercase(), + &self.provider.get_model_config().model_name, + vec![self.mcp_extension.clone()], + ) + .await + .map_err(|e| anyhow::anyhow!("{}", e))? + } else { + self.provider.clone() + }; + + let temp_dir = tempfile::tempdir()?; + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let permission_manager = Arc::new(PermissionManager::new(temp_dir.path().to_path_buf())); + let agent = Agent::with_config(AgentConfig::new( + session_manager.clone(), + permission_manager, + None, + GooseMode::Approve, + true, + GoosePlatform::GooseCli, + )); + + let session = session_manager + .create_session( + std::env::current_dir()?, + "permission_test".to_string(), + SessionType::User, + ) + .await?; + + agent.update_provider(provider, &session.id).await?; + agent + .add_extension(self.mcp_extension.clone(), &session.id) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let message = + Message::user().with_text("Use the get_code tool and output only its result."); + let session_config = SessionConfig { + id: session.id, + schedule_id: None, + max_turns: Some(5), + retry_config: None, + }; + + let mut stream = agent.reply(message, session_config, None).await?; + let mut saw_action_required = false; + + while let Some(event) = stream.next().await { + let event = event?; + if let AgentEvent::Message(ref msg) = event { + for content in &msg.content { + if let MessageContent::ActionRequired(ar) = content { + if let ActionRequiredData::ToolConfirmation { ref id, .. } = ar.data { + saw_action_required = true; + agent + .handle_confirmation( + id.clone(), + PermissionConfirmation { + principal_type: PrincipalType::Tool, + permission: permission.clone(), + }, + ) + .await; + } + } + } + } + } + + assert!(saw_action_required); + println!("=== {}::{} ===", self.name, label); + Ok(()) + } + + async fn test_permission_allow(&self) -> Result<()> { + self.run_permission_test(Permission::AllowOnce, "permission_allow") + .await + } + + async fn test_permission_deny(&self) -> Result<()> { + self.run_permission_test(Permission::DenyOnce, "permission_deny") + .await + } } fn load_env() { @@ -507,7 +612,7 @@ async fn test_provider( ExtensionManagerCapabilities { mcpui: false }, )); extension_manager - .add_extension(mcp_extension, None, None, None) + .add_extension(mcp_extension.clone(), None, None, None) .await .expect("failed to add extension"); @@ -517,6 +622,7 @@ async fn test_provider( extension_manager, is_cli_provider, model_switch_name.map(String::from), + mcp_extension, ); let _mcp = mcp; let result = tester.run_test_suite().await; diff --git a/scripts/test_providers_lib.sh b/scripts/test_providers_lib.sh index b56f13a89982..2beac0d4eae4 100755 --- a/scripts/test_providers_lib.sh +++ b/scripts/test_providers_lib.sh @@ -17,7 +17,7 @@ litellm -> gpt-4o-mini sagemaker_tgi -> sagemaker-tgi-endpoint github_copilot -> gpt-4.1 chatgpt_codex -> gpt-5.1-codex -claude-code -> claude-sonnet-4-20250514 +claude-code -> default codex -> gpt-5.2-codex gemini-cli -> gemini-2.5-pro cursor-agent -> auto