Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ impl Agent {
&self,
tool_call: mcp_core::tool::ToolCall,
request_id: String,
cancellation_token: Option<CancellationToken>,
) -> (String, Result<ToolCallResult, ToolError>) {
// Check if this tool call should be allowed based on repetition monitoring
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
Expand Down Expand Up @@ -345,10 +346,12 @@ impl Agent {

let task_config =
TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx);

subagent_execute_task_tool::run_tasks(
tool_call.arguments.clone(),
task_config,
&self.tasks_manager,
cancellation_token,
)
.await
} else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX {
Expand Down Expand Up @@ -914,7 +917,7 @@ impl Agent {
for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() {
let (req_id, tool_result) = self
.dispatch_tool_call(tool_call, request.id.clone())
.dispatch_tool_call(tool_call, request.id.clone(), cancel_token.clone())
.await;

tool_futures.push((
Expand Down Expand Up @@ -951,6 +954,7 @@ impl Agent {
tool_futures_arc.clone(),
&mut permission_manager,
message_tool_response.clone(),
cancel_token.clone(),
);

while let Some(msg) = tool_approval_stream.try_next().await? {
Expand Down
11 changes: 0 additions & 11 deletions crates/goose/src/agents/sub_recipe_execution_tool/mod.rs

This file was deleted.

186 changes: 0 additions & 186 deletions crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs

This file was deleted.

31 changes: 0 additions & 31 deletions crates/goose/src/agents/sub_recipe_execution_tool/workers.rs

This file was deleted.

22 changes: 20 additions & 2 deletions crates/goose/src/agents/subagent_execution_tool/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;

const EXECUTION_STATUS_COMPLETED: &str = "completed";
const DEFAULT_MAX_WORKERS: usize = 10;
Expand All @@ -21,14 +22,22 @@ pub async fn execute_single_task(
task: &Task,
notifier: mpsc::Sender<JsonRpcMessage>,
task_config: TaskConfig,
cancellation_token: Option<CancellationToken>,
) -> ExecutionResponse {
let start_time = Instant::now();
let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
vec![task.clone()],
DisplayMode::SingleTaskOutput,
notifier,
cancellation_token.clone(),
));
let result = process_task(task, task_execution_tracker.clone(), task_config).await;
let result = process_task(
task,
task_execution_tracker.clone(),
task_config,
cancellation_token.unwrap_or_default(),
)
.await;

// Complete the task in the tracker
task_execution_tracker
Expand All @@ -49,11 +58,13 @@ pub async fn execute_tasks_in_parallel(
tasks: Vec<Task>,
notifier: Sender<JsonRpcMessage>,
task_config: TaskConfig,
cancellation_token: Option<CancellationToken>,
) -> ExecutionResponse {
let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
tasks.clone(),
DisplayMode::MultipleTasksOutput,
notifier,
cancellation_token.clone(),
));
let start_time = Instant::now();
let task_count = tasks.len();
Expand All @@ -71,7 +82,12 @@ pub async fn execute_tasks_in_parallel(
return create_error_response(e);
}

let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone());
let shared_state = create_shared_state(
task_rx,
result_tx,
task_execution_tracker.clone(),
cancellation_token.unwrap_or_default(),
);

let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS);
let mut worker_handles = Vec::new();
Expand Down Expand Up @@ -135,12 +151,14 @@ fn create_shared_state(
task_rx: mpsc::Receiver<Task>,
result_tx: mpsc::Sender<TaskResult>,
task_execution_tracker: Arc<TaskExecutionTracker>,
cancellation_token: CancellationToken,
) -> Arc<SharedState> {
Arc::new(SharedState {
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
result_sender: result_tx,
active_workers: Arc::new(AtomicUsize::new(0)),
task_execution_tracker,
cancellation_token,
})
}

Expand Down
14 changes: 11 additions & 3 deletions crates/goose/src/agents/subagent_execution_tool/lib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ use crate::agents::subagent_task_config::TaskConfig;
use rmcp::model::JsonRpcMessage;
use serde_json::{json, Value};
use tokio::sync::mpsc::Sender;
use tokio_util::sync::CancellationToken;

pub async fn execute_tasks(
input: Value,
execution_mode: ExecutionMode,
notifier: Sender<JsonRpcMessage>,
task_config: TaskConfig,
tasks_manager: &TasksManager,
cancellation_token: Option<CancellationToken>,
) -> Result<Value, String> {
let task_ids: Vec<String> = serde_json::from_value(
input
Expand All @@ -31,7 +33,8 @@ pub async fn execute_tasks(
match execution_mode {
ExecutionMode::Sequential => {
if task_count == 1 {
let response = execute_single_task(&tasks[0], notifier, task_config).await;
let response =
execute_single_task(&tasks[0], notifier, task_config, cancellation_token).await;
handle_response(response)
} else {
Err("Sequential execution mode requires exactly one task".to_string())
Expand All @@ -47,8 +50,13 @@ pub async fn execute_tasks(
}
))
} else {
let response: ExecutionResponse =
execute_tasks_in_parallel(tasks, notifier.clone(), task_config).await;
let response: ExecutionResponse = execute_tasks_in_parallel(
tasks,
notifier.clone(),
task_config,
cancellation_token,
)
.await;
handle_response(response)
}
}
Expand Down
Loading
Loading