diff --git a/Cargo.lock b/Cargo.lock index cf3088070085..3e417bacd754 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2705,6 +2705,7 @@ dependencies = [ "jsonschema", "nix 0.30.1", "once_cell", + "open", "rand 0.8.5", "regex", "rmcp", @@ -3531,6 +3532,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + [[package]] name = "is-terminal" version = "0.4.16" @@ -3542,6 +3552,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -4418,6 +4438,17 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "open" +version = "5.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2483562e62ea94312f3576a7aca397306df7990b8d89033e18766744377ef95" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl" version = "0.10.73" diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index c8651e6a9433..d5533aa80ecf 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -23,6 +23,7 @@ use crate::agents::recipe_tools::dynamic_task_tools::{ use crate::agents::retry::{RetryManager, RetryResult}; use crate::agents::router_tools::ROUTER_LLM_SEARCH_TOOL_NAME; use crate::agents::sub_recipe_manager::SubRecipeManager; +use crate::agents::subagent_execution_tool::lib::ExecutionMode; use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, }; @@ -297,6 +298,7 @@ impl Agent { permission_check_result: &PermissionCheckResult, message_tool_response: Arc>, cancel_token: Option, + session: Option, ) -> Result> { let mut tool_futures: Vec<(String, ToolStream)> = Vec::new(); @@ -304,7 +306,12 @@ impl Agent { for request in &permission_check_result.approved { if let Ok(tool_call) = request.tool_call.clone() { let (req_id, tool_result) = self - .dispatch_tool_call(tool_call, request.id.clone(), cancel_token.clone()) + .dispatch_tool_call( + tool_call, + request.id.clone(), + cancel_token.clone(), + session.clone(), + ) .await; tool_futures.push(( @@ -384,6 +391,7 @@ impl Agent { tool_call: CallToolRequestParam, request_id: String, cancellation_token: Option, + session: Option, ) -> (String, Result) { if tool_call.name == PLATFORM_MANAGE_SCHEDULE_TOOL_NAME { let arguments = tool_call @@ -451,16 +459,89 @@ impl Agent { .dispatch_sub_recipe_tool_call(&tool_call.name, arguments, &self.tasks_manager) .await } else if tool_call.name == SUBAGENT_EXECUTE_TASK_TOOL_NAME { - let provider = self.provider().await.ok(); - let arguments = tool_call - .arguments - .clone() - .map(Value::Object) - .unwrap_or(Value::Object(serde_json::Map::new())); + let provider = match self.provider().await { + Ok(p) => p, + Err(_) => { + return ( + request_id, + Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Provider is required".to_string(), + None, + )), + ); + } + }; + let session = match session.as_ref() { + Some(s) => s, + None => { + return ( + request_id, + Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Session is required".to_string(), + None, + )), + ); + } + }; + let parent_session_id = session.id.to_string(); + let parent_working_dir = session.working_dir.clone(); + + let task_config = TaskConfig::new( + provider, + parent_session_id, + parent_working_dir, + get_enabled_extensions(), + ); + + let arguments = match tool_call.arguments.clone() { + Some(args) => Value::Object(args), + None => { + return ( + request_id, + Err(ErrorData::new( + ErrorCode::INVALID_PARAMS, + "Tool call arguments are required".to_string(), + None, + )), + ); + } + }; + let task_ids: Vec = match arguments.get("task_ids") { + Some(v) => match serde_json::from_value(v.clone()) { + Ok(ids) => ids, + Err(_) => { + return ( + request_id, + Err(ErrorData::new( + ErrorCode::INVALID_PARAMS, + "Invalid task_ids format".to_string(), + None, + )), + ); + } + }, + None => { + return ( + request_id, + Err(ErrorData::new( + ErrorCode::INVALID_PARAMS, + "task_ids parameter is required".to_string(), + None, + )), + ); + } + }; + + let execution_mode = arguments + .get("execution_mode") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or(ExecutionMode::Sequential); - let task_config = TaskConfig::new(provider); subagent_execute_task_tool::run_tasks( - arguments, + task_ids, + execution_mode, task_config, &self.tasks_manager, cancellation_token, @@ -1162,6 +1243,7 @@ impl Agent { &permission_check_result, message_tool_response.clone(), cancel_token.clone(), + session.clone(), ).await?; let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); @@ -1172,6 +1254,7 @@ impl Agent { tool_futures_arc.clone(), message_tool_response.clone(), cancel_token.clone(), + session.clone(), &inspection_results, ); diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index b5cf270bf912..39bf903a3579 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -16,7 +16,6 @@ mod router_tool_selector; mod router_tools; mod schedule_tool; pub mod sub_recipe_manager; -pub mod subagent; pub mod subagent_execution_tool; pub mod subagent_handler; mod subagent_task_config; @@ -30,6 +29,5 @@ pub use agent::{Agent, AgentEvent}; pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; -pub use subagent::{SubAgent, SubAgentProgress, SubAgentStatus}; pub use subagent_task_config::TaskConfig; pub use types::{FrontendTool, RetryConfig, SessionConfig, SuccessCheck}; diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 641e65a0341c..c4b7b34d1b20 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -84,48 +84,6 @@ impl Agent { Ok((tools, toolshim_tools, system_prompt)) } - /// Generate a response from the LLM provider - /// Handles toolshim transformations if needed - pub(crate) async fn generate_response_from_provider( - provider: Arc, - system_prompt: &str, - messages: &[Message], - tools: &[Tool], - toolshim_tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let config = provider.get_model_config(); - - // Convert tool messages to text if toolshim is enabled - let messages_for_provider = if config.toolshim { - convert_tool_messages_to_text(messages) - } else { - Conversation::new_unvalidated(messages.to_vec()) - }; - - // Call the provider to get a response - let (mut response, mut usage) = provider - .complete(system_prompt, messages_for_provider.messages(), tools) - .await?; - - // Ensure we have token counts, estimating if necessary - usage - .ensure_tokens( - system_prompt, - messages_for_provider.messages(), - &response, - tools, - ) - .await?; - - crate::providers::base::set_current_model(&usage.model); - - if config.toolshim { - response = toolshim_postprocess(response, toolshim_tools).await?; - } - - Ok((response, usage)) - } - /// Stream a response from the LLM provider. /// Handles toolshim transformations if needed pub(crate) async fn stream_response_from_provider( diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs deleted file mode 100644 index 7daac47255d8..000000000000 --- a/crates/goose/src/agents/subagent.rs +++ /dev/null @@ -1,334 +0,0 @@ -use crate::agents::subagent_task_config::DEFAULT_SUBAGENT_MAX_TURNS; -use crate::{ - agents::{extension_manager::ExtensionManager, Agent, TaskConfig}, - config::get_all_extensions, - prompt_template::render_global_file, - providers::errors::ProviderError, -}; -use anyhow::anyhow; -use chrono::{DateTime, Utc}; -use rmcp::model::Tool; -use rmcp::model::{ErrorCode, ErrorData}; -use serde::{Deserialize, Serialize}; -// use serde_json::{self}; -use crate::conversation::message::{Message, MessageContent, ToolRequest}; -use crate::conversation::Conversation; -use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{Mutex, RwLock}; -use tokio_util::sync::CancellationToken; -use tracing::{debug, error, instrument}; - -/// Status of a subagent -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum SubAgentStatus { - Ready, // Ready to process messages - Processing, // Currently working on a task - Completed(String), // Task completed (with optional message for success/error) - Terminated, // Manually terminated -} - -/// Progress information for a subagent -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SubAgentProgress { - pub subagent_id: String, - pub status: SubAgentStatus, - pub message: String, - pub turn: usize, - pub max_turns: Option, - pub timestamp: DateTime, -} - -/// A specialized agent that can handle specific tasks independently -pub struct SubAgent { - pub id: String, - pub conversation: Arc>, - pub status: Arc>, - pub config: TaskConfig, - pub turn_count: Arc>, - pub created_at: DateTime, - pub extension_manager: Arc>, -} - -impl SubAgent { - /// Create a new subagent with the given configuration and provider - #[instrument(skip(task_config))] - pub async fn new(task_config: TaskConfig) -> Result, anyhow::Error> { - debug!("Creating new subagent with id: {}", task_config.id); - - // Create a new extension manager for this subagent - let extension_manager = ExtensionManager::new(); - - // Determine which extensions to add: - // 1. If task_config.extensions is Some(vec), use those specific extensions - // 2. If task_config.extensions is None, use all enabled extensions (backward compatibility) - - let extensions_to_add = if let Some(ref extensions) = task_config.extensions { - // Use the explicitly specified extensions - extensions.clone() - } else { - // Default behavior: use all enabled extensions - get_all_extensions() - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect() - }; - - // Add the determined extensions to the subagent's extension manager - for extension in extensions_to_add { - if let Err(e) = extension_manager.add_extension(extension).await { - debug!("Failed to add extension to subagent: {}", e); - // Continue with other extensions even if one fails - } - } - - let subagent = Arc::new(SubAgent { - id: task_config.id.clone(), - conversation: Arc::new(Mutex::new(Conversation::new_unvalidated(Vec::new()))), - status: Arc::new(RwLock::new(SubAgentStatus::Ready)), - config: task_config, - turn_count: Arc::new(Mutex::new(0)), - created_at: Utc::now(), - extension_manager: Arc::new(RwLock::new(extension_manager)), - }); - - debug!("Subagent {} created successfully", subagent.id); - Ok(subagent) - } - - /// Update the status of the subagent - async fn set_status(&self, status: SubAgentStatus) { - // Update the status first, then release the lock - { - let mut current_status = self.status.write().await; - *current_status = status.clone(); - } // Write lock is released here! - } - - /// Process a message and generate a response using the subagent's provider - #[instrument(skip(self, message))] - pub async fn reply_subagent( - &self, - message: String, - task_config: TaskConfig, - ) -> Result { - debug!("Processing message for subagent {}", self.id); - - // Get provider from task config - let provider = self - .config - .provider - .as_ref() - .ok_or_else(|| anyhow!("No provider configured for subagent"))?; - - // Set status to processing - self.set_status(SubAgentStatus::Processing).await; - - // Add user message to conversation - let user_message = Message::user().with_text(message.clone()); - { - let mut conversation = self.conversation.lock().await; - conversation.push(user_message.clone()); - } - - // Get the current conversation for context - let mut messages = { - let conversation = self.conversation.lock().await; - conversation.clone() - }; - - // Get tools from the subagent's own extension manager - let tools: Vec = self - .extension_manager - .read() - .await - .get_prefixed_tools(None) - .await - .unwrap_or_default(); - - let toolshim_tools: Vec = vec![]; - - // Build system prompt using the template - let system_prompt = self.build_system_prompt(&tools).await?; - - // Generate response from provider with loop for tool processing (max_turns iterations) - let mut loop_count = 0; - let max_turns = self.config.max_turns.unwrap_or(DEFAULT_SUBAGENT_MAX_TURNS); - let mut last_error: Option = None; - - // Generate response from provider - loop { - loop_count += 1; - - match Agent::generate_response_from_provider( - Arc::clone(provider), - &system_prompt, - messages.messages(), - &tools, - &toolshim_tools, - ) - .await - { - Ok((response, _usage)) => { - // Process any tool calls in the response - let tool_requests: Vec = response - .content - .iter() - .filter_map(|content| { - if let MessageContent::ToolRequest(req) = content { - Some(req.clone()) - } else { - None - } - }) - .collect(); - - // If there are no tool requests, we're done - if tool_requests.is_empty() || loop_count >= max_turns { - self.add_message(response.clone()).await; - messages.push(response.clone()); - - // Set status back to ready - self.set_status(SubAgentStatus::Completed("Completed!".to_string())) - .await; - break; - } - - // Add the assistant message with tool calls to the conversation - messages.push(response.clone()); - - // Process each tool request and create user response messages - for request in &tool_requests { - if let Ok(tool_call) = &request.tool_call { - // Handle platform tools or dispatch to extension manager - let tool_result = match self - .extension_manager - .read() - .await - .dispatch_tool_call(tool_call.clone(), CancellationToken::default()) - .await - { - Ok(result) => result.result.await, - Err(e) => Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - e.to_string(), - None, - )), - }; - - match tool_result { - Ok(result) => { - // Create a user message with the tool response - let tool_response_message = Message::user() - .with_tool_response(request.id.clone(), Ok(result.clone())); - messages.push(tool_response_message); - } - Err(e) => { - // Create a user message with the tool error - let tool_error_message = Message::user().with_tool_response( - request.id.clone(), - Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - e.to_string(), - None, - )), - ); - messages.push(tool_error_message); - } - } - } - } - - // Continue the loop to get the next response from the provider - } - Err(ProviderError::ContextLengthExceeded(_)) => { - self.set_status(SubAgentStatus::Completed( - "Context length exceeded".to_string(), - )) - .await; - last_error = Some(anyhow::anyhow!("Context length exceeded")); - break; - } - Err(ProviderError::RateLimitExceeded { .. }) => { - self.set_status(SubAgentStatus::Completed("Rate limit exceeded".to_string())) - .await; - last_error = Some(anyhow::anyhow!("Rate limit exceeded")); - break; - } - Err(e) => { - self.set_status(SubAgentStatus::Completed(format!("Error: {}", e))) - .await; - error!("Error: {}", e); - last_error = Some(anyhow::anyhow!("Provider error: {}", e)); - break; - } - } - } - - // Handle error cases or return the last message - if let Some(error) = last_error { - Err(error) - } else { - Ok(messages) - } - } - - /// Add a message to the conversation (for tracking agent responses) - async fn add_message(&self, message: Message) { - let mut conversation = self.conversation.lock().await; - conversation.push(message); - } - - /// Build the system prompt for the subagent using the template - async fn build_system_prompt(&self, available_tools: &[Tool]) -> Result { - let mut context = HashMap::new(); - - // Add basic context - context.insert( - "current_date_time", - serde_json::Value::String(Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string()), - ); - context.insert("subagent_id", serde_json::Value::String(self.id.clone())); - - // Add max turns if configured - if let Some(max_turns) = self.config.max_turns { - context.insert( - "max_turns", - serde_json::Value::Number(serde_json::Number::from(max_turns)), - ); - } - - // Add available tools with descriptions for better context - let tools_with_descriptions: Vec = available_tools - .iter() - .map(|t| { - if let Some(description) = &t.description { - format!("{}: {}", t.name, description) - } else { - t.name.to_string() - } - }) - .collect(); - - context.insert( - "available_tools", - serde_json::Value::String(if tools_with_descriptions.is_empty() { - "None".to_string() - } else { - tools_with_descriptions.join(", ") - }), - ); - - // Add tool count for context - context.insert( - "tool_count", - serde_json::Value::Number(serde_json::Number::from(available_tools.len())), - ); - - // Render the subagent system prompt template - let system_prompt = render_global_file("subagent_system.md", &context) - .map_err(|e| anyhow!("Failed to render subagent system prompt: {}", e))?; - - Ok(system_prompt) - } -} diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 695d8ddb97ab..48947c1438a8 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -12,21 +12,13 @@ use tokio::sync::mpsc::Sender; use tokio_util::sync::CancellationToken; pub async fn execute_tasks( - input: Value, + task_ids: Vec, execution_mode: ExecutionMode, notifier: Sender, task_config: TaskConfig, tasks_manager: &TasksManager, cancellation_token: Option, ) -> Result { - let task_ids: Vec = serde_json::from_value( - input - .get("task_ids") - .ok_or("Missing task_ids field")? - .clone(), - ) - .map_err(|e| format!("Failed to parse task_ids: {}", e))?; - let tasks = tasks_manager.get_tasks(&task_ids).await?; let task_count = tasks.len(); diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index 474912f0c416..b70f71d3f205 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -1,14 +1,12 @@ use std::borrow::Cow; -use rmcp::model::{Content, ErrorCode, ErrorData, ServerNotification, Tool, ToolAnnotations}; -use serde_json::Value; - use crate::agents::subagent_task_config::TaskConfig; use crate::agents::{ subagent_execution_tool::lib::execute_tasks, subagent_execution_tool::task_types::ExecutionMode, subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }; +use rmcp::model::{Content, ErrorCode, ErrorData, ServerNotification, Tool, ToolAnnotations}; use rmcp::object; use tokio::sync::mpsc; use tokio_stream; @@ -62,7 +60,8 @@ pub fn create_subagent_execute_task_tool() -> Tool { } pub async fn run_tasks( - execute_data: Value, + task_ids: Vec, + execution_mode: ExecutionMode, task_config: TaskConfig, tasks_manager: &TasksManager, cancellation_token: Option, @@ -71,14 +70,8 @@ pub async fn run_tasks( let tasks_manager_clone = tasks_manager.clone(); let result_future = async move { - let execute_data_clone = execute_data.clone(); - let execution_mode = execute_data_clone - .get("execution_mode") - .and_then(|v| serde_json::from_value::(v.clone()).ok()) - .unwrap_or_default(); - match execute_tasks( - execute_data, + task_ids, execution_mode, notification_tx, task_config, diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs index 96d9f25d0f33..ccb9fdb717cd 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -74,7 +74,7 @@ async fn handle_inline_recipe_task( mut task_config: TaskConfig, cancellation_token: CancellationToken, ) -> Result { - use crate::agents::subagent_handler::run_complete_subagent_task_with_options; + use crate::agents::subagent_handler::run_complete_subagent_task; use crate::recipe::Recipe; let recipe_value = task @@ -91,14 +91,23 @@ async fn handle_inline_recipe_task( .and_then(|v| v.as_bool()) .unwrap_or(false); - task_config.extensions = recipe.extensions.clone(); + if let Some(exts) = recipe.extensions { + if !exts.is_empty() { + task_config.extensions = exts.clone(); + } + } let instruction = recipe .instructions .or(recipe.prompt) .ok_or_else(|| "No instructions or prompt in recipe".to_string())?; + let result = tokio::select! { - result = run_complete_subagent_task_with_options(instruction, task_config, return_last_only) => result, + result = run_complete_subagent_task( + instruction, + task_config, + return_last_only, + ) => result, _ = cancellation_token.cancelled() => { return Err("Task cancelled".to_string()); } diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index c2cb724a8eb0..83f4cfb054fe 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -1,35 +1,30 @@ -use crate::agents::subagent::SubAgent; -use crate::agents::subagent_task_config::TaskConfig; -use anyhow::Result; +use crate::{ + agents::{subagent_task_config::TaskConfig, AgentEvent, SessionConfig}, + conversation::{message::Message, Conversation}, + execution::manager::AgentManager, + session::SessionManager, +}; +use anyhow::{anyhow, Result}; +use futures::future::BoxFuture; +use futures::StreamExt; use rmcp::model::{ErrorCode, ErrorData}; - -/// Standalone function to run a complete subagent task -pub async fn run_complete_subagent_task( - text_instruction: String, - task_config: TaskConfig, -) -> Result { - run_complete_subagent_task_with_options(text_instruction, task_config, false).await -} +use tracing::debug; /// Standalone function to run a complete subagent task with output options -pub async fn run_complete_subagent_task_with_options( +pub async fn run_complete_subagent_task( text_instruction: String, task_config: TaskConfig, return_last_only: bool, ) -> Result { - // Create the subagent with the parent agent's provider - let subagent = SubAgent::new(task_config.clone()).await.map_err(|e| { - ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Failed to create subagent: {}", e), - None, - ) - })?; - - // Execute the subagent task - let messages = subagent - .reply_subagent(text_instruction, task_config) - .await?; + let messages = get_agent_messages(text_instruction, task_config) + .await + .map_err(|e| { + ErrorData::new( + ErrorCode::INTERNAL_ERROR, + format!("Failed to execute task: {}", e), + None, + ) + })?; // Extract text content based on return_last_only flag let response_text = if return_last_only { @@ -94,3 +89,73 @@ pub async fn run_complete_subagent_task_with_options( // Return the result Ok(response_text) } + +fn get_agent_messages( + text_instruction: String, + task_config: TaskConfig, +) -> BoxFuture<'static, Result> { + Box::pin(async move { + let agent_manager = AgentManager::instance() + .await + .map_err(|e| anyhow!("Failed to create AgentManager: {}", e))?; + let parent_session_id = task_config.parent_session_id; + let working_dir = task_config.parent_working_dir; + let session = SessionManager::create_session( + working_dir.clone(), + format!("Subagent task for: {}", parent_session_id), + ) + .await + .map_err(|e| anyhow!("Failed to create a session for sub agent: {}", e))?; + + let agent = agent_manager + .get_or_create_agent(session.id.clone()) + .await + .map_err(|e| anyhow!("Failed to get sub agent session file path: {}", e))?; + agent + .update_provider(task_config.provider) + .await + .map_err(|e| anyhow!("Failed to set provider on sub agent: {}", e))?; + + for extension in task_config.extensions { + if let Err(e) = agent.add_extension(extension.clone()).await { + debug!( + "Failed to add extension '{}' to subagent: {}", + extension.name(), + e + ); + } + } + + let mut session_messages = + Conversation::new_unvalidated( + vec![Message::user().with_text(text_instruction.clone())], + ); + let session_config = SessionConfig { + id: session.id, + working_dir, + schedule_id: None, + execution_mode: None, + max_turns: task_config.max_turns.map(|v| v as u32), + retry_config: None, + }; + + let mut stream = agent + .reply(session_messages.clone(), Some(session_config), None) + .await + .map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?; + while let Some(message_result) = stream.next().await { + match message_result { + Ok(AgentEvent::Message(msg)) => session_messages.push(msg), + Ok(AgentEvent::McpNotification(_)) + | Ok(AgentEvent::ModelChange { .. }) + | Ok(AgentEvent::HistoryReplaced(_)) => {} // Handle informational events + Err(e) => { + tracing::error!("Error receiving message from subagent: {}", e); + break; + } + } + } + + Ok(session_messages) + }) +} diff --git a/crates/goose/src/agents/subagent_task_config.rs b/crates/goose/src/agents/subagent_task_config.rs index 5cdcf76b0228..3cafea2b8351 100644 --- a/crates/goose/src/agents/subagent_task_config.rs +++ b/crates/goose/src/agents/subagent_task_config.rs @@ -1,8 +1,9 @@ +use crate::agents::ExtensionConfig; use crate::providers::base::Provider; use std::env; use std::fmt; +use std::path::PathBuf; use std::sync::Arc; -use uuid::Uuid; /// Default maximum number of turns for task execution pub const DEFAULT_SUBAGENT_MAX_TURNS: usize = 25; @@ -13,17 +14,19 @@ pub const GOOSE_SUBAGENT_MAX_TURNS_ENV_VAR: &str = "GOOSE_SUBAGENT_MAX_TURNS"; /// Configuration for task execution with all necessary dependencies #[derive(Clone)] pub struct TaskConfig { - pub id: String, - pub provider: Option>, + pub provider: Arc, + pub parent_session_id: String, + pub parent_working_dir: PathBuf, + pub extensions: Vec, pub max_turns: Option, - pub extensions: Option>, } impl fmt::Debug for TaskConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TaskConfig") - .field("id", &self.id) .field("provider", &"") + .field("parent_session_id", &self.parent_session_id) + .field("parent_working_dir", &self.parent_working_dir) .field("max_turns", &self.max_turns) .field("extensions", &self.extensions) .finish() @@ -32,22 +35,23 @@ impl fmt::Debug for TaskConfig { impl TaskConfig { /// Create a new TaskConfig with all required dependencies - pub fn new(provider: Option>) -> Self { + pub fn new( + provider: Arc, + parent_session_id: String, + parent_working_dir: PathBuf, + extensions: Vec, + ) -> Self { Self { - id: Uuid::new_v4().to_string(), provider, + parent_session_id, + parent_working_dir, + extensions, max_turns: Some( env::var(GOOSE_SUBAGENT_MAX_TURNS_ENV_VAR) .ok() .and_then(|val| val.parse::().ok()) .unwrap_or(DEFAULT_SUBAGENT_MAX_TURNS), ), - extensions: None, } } - - /// Get a reference to the provider - pub fn provider(&self) -> Option<&Arc> { - self.provider.as_ref() - } } diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 4869d124f994..3cc7c5d08a94 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -29,7 +29,7 @@ impl From>> for ToolCallResult { } use super::agent::{tool_stream, ToolStream}; -use crate::agents::Agent; +use crate::agents::{Agent, SessionConfig}; use crate::conversation::message::{Message, ToolRequest}; use crate::tool_inspection::get_security_finding_id_from_results; @@ -53,6 +53,7 @@ impl Agent { tool_futures: Arc>>, message_tool_response: Arc>, cancellation_token: Option, + session: Option, inspection_results: &'a [crate::tool_inspection::InspectionResult], ) -> BoxStream<'a, anyhow::Result> { try_stream! { @@ -90,7 +91,7 @@ impl Agent { } if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { - let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone()).await; + let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone(), session.clone()).await; let mut futures = tool_futures.lock().await; futures.push((req_id, match tool_result { diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index aac33feab5d0..1c366024427c 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -630,7 +630,7 @@ mod final_output_tool_tests { }; let (_, result) = agent - .dispatch_tool_call(tool_call, "request_id".to_string(), None) + .dispatch_tool_call(tool_call, "request_id".to_string(), None, None) .await; assert!(result.is_ok(), "Tool call should succeed"); diff --git a/crates/goose/tests/private_tests.rs b/crates/goose/tests/private_tests.rs index acd89881ebf7..8ac97aa15048 100644 --- a/crates/goose/tests/private_tests.rs +++ b/crates/goose/tests/private_tests.rs @@ -817,7 +817,7 @@ async fn test_schedule_tool_dispatch() { }; let (request_id, result) = agent - .dispatch_tool_call(tool_call, "test_dispatch".to_string(), None) + .dispatch_tool_call(tool_call, "test_dispatch".to_string(), None, None) .await; assert_eq!(request_id, "test_dispatch"); assert!(result.is_ok());