diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index abef20b87bf6..3f2b983efbe1 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -10,7 +10,6 @@ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use goose::conversation::message::{Message, MessageContent}; use goose::conversation::Conversation; -use goose::execution::SessionExecutionMode; use goose::mcp_utils::ToolResult; use goose::permission::{Permission, PermissionConfirmation}; use goose::session::SessionManager; @@ -207,10 +206,7 @@ async fn reply_handler( let task_tx = tx.clone(); drop(tokio::spawn(async move { - let agent = match state - .get_agent(session_id.clone(), SessionExecutionMode::Interactive) - .await - { + let agent = match state.get_agent(session_id.clone()).await { Ok(agent) => agent, Err(e) => { tracing::error!("Failed to get session agent: {}", e); diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index cf6696a61fe6..b8b916bf9102 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,6 +1,5 @@ use axum::http::StatusCode; use goose::execution::manager::AgentManager; -use goose::execution::SessionExecutionMode; use goose::scheduler_trait::SchedulerTrait; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; @@ -46,14 +45,8 @@ impl AppState { } } - pub async fn get_agent( - &self, - session_id: String, - mode: SessionExecutionMode, - ) -> anyhow::Result> { - self.agent_manager - .get_or_create_agent(session_id, mode) - .await + pub async fn get_agent(&self, session_id: String) -> anyhow::Result> { + self.agent_manager.get_or_create_agent(session_id).await } /// Get agent for route handlers - always uses Interactive mode and converts any error to 500 @@ -61,11 +54,9 @@ impl AppState { &self, session_id: String, ) -> Result, StatusCode> { - self.get_agent(session_id, SessionExecutionMode::Interactive) - .await - .map_err(|e| { - tracing::error!("Failed to get agent: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - }) + self.get_agent(session_id).await.map_err(|e| { + tracing::error!("Failed to get agent: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + }) } } diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index fc3257a7bb71..25eb8bdc4a6d 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -1,6 +1,3 @@ -//! Agent lifecycle management with session isolation - -use super::SessionExecutionMode; use crate::agents::Agent; use crate::config::APP_STRATEGY; use crate::model::ModelConfig; @@ -112,49 +109,27 @@ impl AgentManager { Ok(()) } - pub async fn get_or_create_agent( - &self, - session_id: String, - mode: SessionExecutionMode, - ) -> Result> { - let agent = { + pub async fn get_or_create_agent(&self, session_id: String) -> Result> { + { let mut sessions = self.sessions.write().await; - if let Some(agent) = sessions.get(&session_id) { - debug!("Found existing agent for session {}", session_id); - return Ok(Arc::clone(agent)); - } - - info!( - "Creating new agent for session {} with mode {}", - session_id, mode - ); - let agent = Arc::new(Agent::new()); - sessions.put(session_id.clone(), Arc::clone(&agent)); - agent - }; - - match &mode { - SessionExecutionMode::Interactive | SessionExecutionMode::Background => { - debug!("Setting scheduler on agent for session {}", session_id); - agent.set_scheduler(Arc::clone(&self.scheduler)).await; - } - SessionExecutionMode::SubTask { .. } => { - debug!( - "SubTask mode for session {}, skipping scheduler setup", - session_id - ); + if let Some(existing) = sessions.get(&session_id) { + return Ok(Arc::clone(existing)); } } + let agent = Arc::new(Agent::new()); + agent.set_scheduler(Arc::clone(&self.scheduler)).await; if let Some(provider) = &*self.default_provider.read().await { - debug!( - "Setting default provider on agent for session {}", - session_id - ); - let _ = agent.update_provider(Arc::clone(provider)).await; + agent.update_provider(Arc::clone(provider)).await?; } - Ok(agent) + let mut sessions = self.sessions.write().await; + if let Some(existing) = sessions.get(&session_id) { + Ok(Arc::clone(existing)) + } else { + sessions.put(session_id, agent.clone()); + Ok(agent) + } } pub async fn remove_session(&self, session_id: &str) -> Result<()> { diff --git a/crates/goose/tests/execution_tests.rs b/crates/goose/tests/execution_tests.rs index 55ec61134f67..07e8932563b7 100644 --- a/crates/goose/tests/execution_tests.rs +++ b/crates/goose/tests/execution_tests.rs @@ -33,24 +33,15 @@ mod execution_tests { let session1 = uuid::Uuid::new_v4().to_string(); let session2 = uuid::Uuid::new_v4().to_string(); - let agent1 = manager - .get_or_create_agent(session1.clone(), SessionExecutionMode::Interactive) - .await - .unwrap(); + let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap(); - let agent2 = manager - .get_or_create_agent(session2.clone(), SessionExecutionMode::Interactive) - .await - .unwrap(); + let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap(); // Different sessions should have different agents assert!(!Arc::ptr_eq(&agent1, &agent2)); // Getting the same session should return the same agent - let agent1_again = manager - .get_or_create_agent(session1, SessionExecutionMode::chat()) - .await - .unwrap(); + let agent1_again = manager.get_or_create_agent(session1).await.unwrap(); assert!(Arc::ptr_eq(&agent1, &agent1_again)); @@ -66,18 +57,12 @@ mod execution_tests { let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); for session in &sessions { - manager - .get_or_create_agent(session.clone(), SessionExecutionMode::chat()) - .await - .unwrap(); + manager.get_or_create_agent(session.clone()).await.unwrap(); } // Create a new session after cleanup let new_session = "new-session".to_string(); - let _new_agent = manager - .get_or_create_agent(new_session, SessionExecutionMode::chat()) - .await - .unwrap(); + let _new_agent = manager.get_or_create_agent(new_session).await.unwrap(); assert_eq!(manager.session_count().await, 100); } @@ -89,18 +74,13 @@ mod execution_tests { let manager = AgentManager::instance().await.unwrap(); let session = String::from("remove-test"); - manager - .get_or_create_agent(session.clone(), SessionExecutionMode::chat()) - .await - .unwrap(); + manager.get_or_create_agent(session.clone()).await.unwrap(); assert!(manager.has_session(&session).await); manager.remove_session(&session).await.unwrap(); assert!(!manager.has_session(&session).await); assert!(manager.remove_session(&session).await.is_err()); - - AgentManager::reset_for_test(); } #[tokio::test] @@ -115,9 +95,7 @@ mod execution_tests { let mgr = Arc::clone(&manager); let sess = session.clone(); handles.push(tokio::spawn(async move { - mgr.get_or_create_agent(sess, SessionExecutionMode::chat()) - .await - .unwrap() + mgr.get_or_create_agent(sess).await.unwrap() })); } @@ -132,33 +110,6 @@ mod execution_tests { } assert_eq!(manager.session_count().await, 1); - - AgentManager::reset_for_test(); - } - - #[tokio::test] - #[serial] - async fn test_different_modes_same_session() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - let session_id = String::from("mode-test"); - - // Create initial agent - let agent1 = manager - .get_or_create_agent(session_id.clone(), SessionExecutionMode::chat()) - .await - .unwrap(); - - // Get same session with different mode - should return same agent - // (mode is stored but agent is reused) - let agent2 = manager - .get_or_create_agent(session_id.clone(), SessionExecutionMode::Background) - .await - .unwrap(); - - assert!(Arc::ptr_eq(&agent1, &agent2)); - - AgentManager::reset_for_test(); } #[tokio::test] @@ -176,10 +127,7 @@ mod execution_tests { let sess = session_id.clone(); let mgr_clone = Arc::clone(&manager); handles.push(tokio::spawn(async move { - mgr_clone - .get_or_create_agent(sess, SessionExecutionMode::Interactive) - .await - .unwrap() + mgr_clone.get_or_create_agent(sess).await.unwrap() })); } @@ -190,18 +138,13 @@ mod execution_tests { .map(|r| r.unwrap()) .collect(); - // All should be the same agent (double-check pattern should prevent duplicates) for agent in &agents[1..] { assert!( Arc::ptr_eq(&agents[0], agent), "All concurrent requests should get the same agent" ); } - - // Only one session should exist assert_eq!(manager.session_count().await, 1); - - AgentManager::reset_for_test(); } #[tokio::test] @@ -233,8 +176,6 @@ mod execution_tests { } else { env::remove_var("GOOSE_DEFAULT_MODEL"); } - - AgentManager::reset_for_test(); } #[tokio::test] @@ -260,14 +201,9 @@ mod execution_tests { manager.set_default_provider(Arc::new(test_provider)).await; let session = String::from("provider-test"); - let _agent = manager - .get_or_create_agent(session.clone(), SessionExecutionMode::Interactive) - .await - .unwrap(); + let _agent = manager.get_or_create_agent(session.clone()).await.unwrap(); assert!(manager.has_session(&session).await); - - AgentManager::reset_for_test(); } #[tokio::test] @@ -281,10 +217,7 @@ mod execution_tests { let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); for session in &sessions { - manager - .get_or_create_agent(session.clone(), SessionExecutionMode::chat()) - .await - .unwrap(); + manager.get_or_create_agent(session.clone()).await.unwrap(); // Small delay to ensure different timestamps tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; } @@ -292,23 +225,20 @@ mod execution_tests { // Access the first session again to update its last_used tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; manager - .get_or_create_agent(sessions[0].clone(), SessionExecutionMode::Interactive) + .get_or_create_agent(sessions[0].clone()) .await .unwrap(); // Now create a 101st session - should evict session2 (least recently used) let session101 = String::from("session-101"); manager - .get_or_create_agent(session101.clone(), SessionExecutionMode::Interactive) + .get_or_create_agent(session101.clone()) .await .unwrap(); - // session1 should still exist (recently accessed) - // session2 should be evicted (least recently used) assert!(manager.has_session(&sessions[0]).await); assert!(!manager.has_session(&sessions[1]).await); assert!(manager.has_session(&session101).await); - AgentManager::reset_for_test(); } #[tokio::test] @@ -322,7 +252,5 @@ mod execution_tests { let result = manager.remove_session(&session).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("not found")); - - AgentManager::reset_for_test(); } }