diff --git a/crates/goose-acp/src/bin/server.rs b/crates/goose-acp/src/bin/server.rs index 33ef93e0fff8..c3c6a9aed59a 100644 --- a/crates/goose-acp/src/bin/server.rs +++ b/crates/goose-acp/src/bin/server.rs @@ -1,5 +1,6 @@ use anyhow::Result; use clap::Parser; +use goose::config::paths::Paths; use goose_acp::server_factory::{AcpServer, AcpServerFactoryConfig}; use std::net::SocketAddr; use std::sync::Arc; @@ -36,12 +37,11 @@ async fn main() -> Result<()> { cli.builtins }; - let config = AcpServerFactoryConfig { + let server = Arc::new(AcpServer::new(AcpServerFactoryConfig { builtins, - ..Default::default() - }; - - let server = Arc::new(AcpServer::new(config)); + data_dir: Paths::data_dir(), + config_dir: Paths::config_dir(), + })); let router = goose_acp::transport::create_router(server); let addr: SocketAddr = format!("{}:{}", cli.host, cli.port).parse()?; diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index ed92665e328f..3f344914cb79 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -13,12 +13,12 @@ use goose::conversation::Conversation; use goose::mcp_utils::ToolResult; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::{Permission, PermissionConfirmation}; -use goose::providers::create; +use goose::providers::provider_registry::ProviderConstructor; use goose::session::session_manager::SessionType; use goose::session::{Session, SessionManager}; use rmcp::model::{CallToolResult, RawContent, ResourceContents, Role}; use sacp::schema::{ - AgentCapabilities, AuthenticateRequest, AuthenticateResponse, BlobResourceContents, + AgentCapabilities, AuthMethod, AuthenticateRequest, AuthenticateResponse, BlobResourceContents, CancelNotification, Content, ContentBlock, ContentChunk, EmbeddedResource, EmbeddedResourceResource, ImageContent, InitializeRequest, InitializeResponse, LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, NewSessionRequest, @@ -46,16 +46,9 @@ struct GooseAcpSession { pub struct GooseAcpAgent { sessions: Arc>>, agent: Arc, - provider: Arc, -} - -pub struct AcpServerConfig { - pub provider: Arc, - pub builtins: Vec, - pub data_dir: std::path::PathBuf, - pub config_dir: std::path::PathBuf, - pub goose_mode: goose::config::GooseMode, - pub disable_session_naming: bool, + provider_factory: ProviderConstructor, + config_dir: std::path::PathBuf, + provider_initialized: tokio::sync::OnceCell, } fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result { @@ -278,54 +271,23 @@ impl GooseAcpAgent { Arc::clone(&self.agent.config.permission_manager) } - pub async fn new(builtins: Vec) -> Result { - let config = Config::global(); - - let provider_name: String = config - .get_goose_provider() - .map_err(|e| anyhow::anyhow!("No provider configured: {}", e))?; - - let model_name: String = config - .get_goose_model() - .map_err(|e| anyhow::anyhow!("No model configured: {}", e))?; - - let model_config = goose::model::ModelConfig { - model_name: model_name.clone(), - context_limit: None, - temperature: None, - max_tokens: None, - toolshim: false, - toolshim_model: None, - fast_model: None, - request_params: None, - }; - let provider = create(&provider_name, model_config).await?; - let goose_mode = config - .get_goose_mode() - .unwrap_or(goose::config::GooseMode::Auto); - - Self::with_config(AcpServerConfig { - provider, - builtins, - data_dir: Paths::data_dir(), - config_dir: Paths::config_dir(), - goose_mode, - disable_session_naming: config.get_goose_disable_session_naming().unwrap_or(false), - }) - .await - } - - pub async fn with_config(config: AcpServerConfig) -> Result { - let session_manager = Arc::new(SessionManager::new(config.data_dir)); - let config_dir = config.config_dir.clone(); - let permission_manager = Arc::new(PermissionManager::new(config.config_dir)); + pub async fn new( + provider_factory: ProviderConstructor, + builtins: Vec, + data_dir: std::path::PathBuf, + config_dir: std::path::PathBuf, + goose_mode: goose::config::GooseMode, + disable_session_naming: bool, + ) -> Result { + let session_manager = Arc::new(SessionManager::new(data_dir)); + let permission_manager = Arc::new(PermissionManager::new(config_dir.clone())); let agent = Agent::with_config(AgentConfig::new( Arc::clone(&session_manager), permission_manager, None, - config.goose_mode, - config.disable_session_naming, + goose_mode, + disable_session_naming, )); let agent_ptr = Arc::new(agent); @@ -334,13 +296,15 @@ impl GooseAcpAgent { let config_file = Config::new(&config_path, "goose")?; let extensions = get_enabled_extensions_with_config(&config_file); - add_builtins(&agent_ptr, config.builtins).await; + add_builtins(&agent_ptr, builtins).await; add_extensions(&agent_ptr, extensions).await; Ok(Self { - provider: config.provider.clone(), sessions: Arc::new(Mutex::new(HashMap::new())), agent: agent_ptr, + provider_factory, + config_dir, + provider_initialized: tokio::sync::OnceCell::new(), }) } @@ -354,9 +318,7 @@ impl GooseAcpAgent { ) .await?; - self.agent - .update_provider(self.provider.clone(), &goose_session.id) - .await?; + self.ensure_provider(&goose_session).await?; let session = GooseAcpSession { messages: Conversation::new_unvalidated(Vec::new()), @@ -692,7 +654,15 @@ impl GooseAcpAgent { .embedded_context(true), ) .mcp_capabilities(McpCapabilities::new().http(true)); - Ok(InitializeResponse::new(args.protocol_version).agent_capabilities(capabilities)) + Ok(InitializeResponse::new(args.protocol_version) + .agent_capabilities(capabilities) + .auth_methods(vec![AuthMethod::new( + "goose-provider", + "Configure Provider", + ) + .description( + "Run `goose configure` to set up your AI provider and API key", + )])) } async fn on_new_session( @@ -712,7 +682,9 @@ impl GooseAcpAgent { .map_err(|e| { sacp::Error::internal_error().data(format!("Failed to create session: {}", e)) })?; - self.update_session_with_provider(&goose_session).await?; + self.ensure_provider(&goose_session).await.map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) + })?; for mcp_server in args.mcp_servers { let config = match mcp_server_to_extension_config(mcp_server) { @@ -746,16 +718,21 @@ impl GooseAcpAgent { Ok(NewSessionResponse::new(SessionId::new(goose_session.id))) } - async fn update_session_with_provider( - &self, - goose_session: &Session, - ) -> Result<(), sacp::Error> { - self.agent - .update_provider(self.provider.clone(), &goose_session.id) - .await - .map_err(|e| { - sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) - })?; + // Called at most once via OnceCell; returns the model_id used. + async fn create_provider(&self, session: &Session) -> Result { + let config_path = self.config_dir.join(CONFIG_YAML_NAME); + let config = Config::new(&config_path, "goose")?; + let model_id = config.get_goose_model()?; + let model_config = goose::model::ModelConfig::new(&model_id)?; + let provider = (self.provider_factory)(model_config).await?; + self.agent.update_provider(provider, &session.id).await?; + Ok(model_id) + } + + async fn ensure_provider(&self, session: &Session) -> Result<()> { + self.provider_initialized + .get_or_try_init(|| self.create_provider(session)) + .await?; Ok(()) } @@ -773,7 +750,9 @@ impl GooseAcpAgent { sacp::Error::invalid_params() .data(format!("Failed to load session {}: {}", session_id, e)) })?; - self.update_session_with_provider(&goose_session).await?; + self.ensure_provider(&goose_session).await.map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) + })?; let conversation = goose_session.conversation.ok_or_else(|| { sacp::Error::internal_error() @@ -1045,7 +1024,13 @@ pub async fn run(builtins: Vec) -> Result<()> { let outgoing = tokio::io::stdout().compat_write(); let incoming = tokio::io::stdin().compat(); - let agent = Arc::new(GooseAcpAgent::new(builtins).await?); + let server = + crate::server_factory::AcpServer::new(crate::server_factory::AcpServerFactoryConfig { + builtins, + data_dir: Paths::data_dir(), + config_dir: Paths::config_dir(), + }); + let agent = server.create_agent().await?; serve(agent, incoming, outgoing).await } diff --git a/crates/goose-acp/src/server_factory.rs b/crates/goose-acp/src/server_factory.rs index 01511e8c4005..5f7a86e4ba09 100644 --- a/crates/goose-acp/src/server_factory.rs +++ b/crates/goose-acp/src/server_factory.rs @@ -1,12 +1,9 @@ use anyhow::Result; -use goose::config::paths::Paths; -use goose::config::Config; -use goose::model::ModelConfig; -use goose::providers::create; +use goose::providers::provider_registry::ProviderConstructor; use std::sync::Arc; use tracing::info; -use crate::server::{AcpServerConfig, GooseAcpAgent}; +use crate::server::GooseAcpAgent; pub struct AcpServerFactoryConfig { pub builtins: Vec, @@ -14,16 +11,6 @@ pub struct AcpServerFactoryConfig { pub config_dir: std::path::PathBuf, } -impl Default for AcpServerFactoryConfig { - fn default() -> Self { - Self { - builtins: vec!["developer".to_string()], - data_dir: Paths::data_dir(), - config_dir: Paths::config_dir(), - } - } -} - pub struct AcpServer { config: AcpServerFactoryConfig, } @@ -34,44 +21,39 @@ impl AcpServer { } pub async fn create_agent(&self) -> Result> { - let global_config = Config::global(); - - let provider_name: String = global_config - .get_goose_provider() - .map_err(|e| anyhow::anyhow!("No provider configured: {}", e))?; + let config_path = self + .config + .config_dir + .join(goose::config::base::CONFIG_YAML_NAME); + let config = goose::config::Config::new(&config_path, "goose")?; - let model_name: String = global_config - .get_goose_model() - .map_err(|e| anyhow::anyhow!("No model configured: {}", e))?; - - let model_config = ModelConfig { - request_params: None, - model_name: model_name.clone(), - context_limit: None, - temperature: None, - max_tokens: None, - toolshim: false, - toolshim_model: None, - fast_model: None, - }; - - let provider = create(&provider_name, model_config).await?; - let goose_mode = global_config + let goose_mode = config .get_goose_mode() .unwrap_or(goose::config::GooseMode::Auto); - - let acp_config = AcpServerConfig { - provider, - builtins: self.config.builtins.clone(), - data_dir: self.config.data_dir.clone(), - config_dir: self.config.config_dir.clone(), + let disable_session_naming = config.get_goose_disable_session_naming().unwrap_or(false); + + let config_dir = self.config.config_dir.clone(); + let provider_factory: ProviderConstructor = Arc::new(move |model_config| { + let config_dir = config_dir.clone(); + Box::pin(async move { + let config_path = config_dir.join(goose::config::base::CONFIG_YAML_NAME); + let config = goose::config::Config::new(&config_path, "goose")?; + let provider_name = config + .get_goose_provider() + .map_err(|_| anyhow::anyhow!("No provider configured"))?; + goose::providers::create(&provider_name, model_config).await + }) + }); + + let agent = GooseAcpAgent::new( + provider_factory, + self.config.builtins.clone(), + self.config.data_dir.clone(), + self.config.config_dir.clone(), goose_mode, - disable_session_naming: global_config - .get_goose_disable_session_naming() - .unwrap_or(false), - }; - - let agent = GooseAcpAgent::with_config(acp_config).await?; + disable_session_naming, + ) + .await?; info!("Created new ACP agent"); Ok(Arc::new(agent)) diff --git a/crates/goose-acp/tests/common_tests/mod.rs b/crates/goose-acp/tests/common_tests/mod.rs index 10c7cc55c411..f307211ba1ab 100644 --- a/crates/goose-acp/tests/common_tests/mod.rs +++ b/crates/goose-acp/tests/common_tests/mod.rs @@ -189,7 +189,7 @@ pub async fn run_configured_extension() { let mcp = McpFixture::new(expected_session_id.clone()).await; let config_yaml = format!( - "extensions:\n lookup:\n enabled: true\n type: streamable_http\n name: lookup\n description: Lookup server\n uri: \"{}\"\n", + "GOOSE_MODEL: gpt-5-nano\nextensions:\n lookup:\n enabled: true\n type: streamable_http\n name: lookup\n description: Lookup server\n uri: \"{}\"\n", mcp.url ); fs::write(temp_dir.path().join(CONFIG_YAML_NAME), config_yaml).unwrap(); diff --git a/crates/goose-acp/tests/fixtures/mod.rs b/crates/goose-acp/tests/fixtures/mod.rs index 33acbb49ccf6..3aa74a9dfcca 100644 --- a/crates/goose-acp/tests/fixtures/mod.rs +++ b/crates/goose-acp/tests/fixtures/mod.rs @@ -5,11 +5,12 @@ use async_trait::async_trait; use fs_err as fs; use goose::builtin_extension::register_builtin_extensions; use goose::config::{GooseMode, PermissionManager}; -use goose::model::ModelConfig; use goose::providers::api_client::{ApiClient, AuthMethod}; +use goose::providers::base::Provider; use goose::providers::openai::OpenAiProvider; +use goose::providers::provider_registry::ProviderConstructor; use goose::session_context::SESSION_ID_HEADER; -use goose_acp::server::{serve, AcpServerConfig, GooseAcpAgent}; +use goose_acp::server::{serve, GooseAcpAgent}; use rmcp::model::{ClientNotification, ClientRequest, Meta, ServerResult}; use rmcp::service::{NotificationContext, RequestContext, ServiceRole}; use rmcp::transport::streamable_http_server::{ @@ -359,48 +360,71 @@ impl McpFixture { } } -#[allow(dead_code)] -pub async fn spawn_acp_server_in_process( - openai_base_url: &str, - builtins: &[String], - data_root: &Path, - goose_mode: GooseMode, -) -> ( - tokio::io::DuplexStream, - tokio::io::DuplexStream, - JoinHandle<()>, - Arc, -) { - fs::create_dir_all(data_root).unwrap(); - let api_client = ApiClient::new( - openai_base_url.to_string(), - AuthMethod::BearerToken("test-key".to_string()), - ) - .unwrap(); - let model_config = ModelConfig::new("gpt-5-nano").unwrap(); - let provider = OpenAiProvider::new(api_client, model_config); - - let config = AcpServerConfig { - provider: Arc::new(provider), - builtins: builtins.to_vec(), - data_dir: data_root.to_path_buf(), - config_dir: data_root.to_path_buf(), - goose_mode, - disable_session_naming: true, - }; +pub type DuplexTransport = sacp::ByteStreams< + tokio_util::compat::Compat, + tokio_util::compat::Compat, +>; +/// Wires up duplex streams, spawns `serve` for the given agent, and returns +/// a ready-to-use sacp transport plus the server handle. +#[allow(dead_code)] +pub async fn serve_agent_in_process( + agent: Arc, +) -> (DuplexTransport, JoinHandle<()>) { let (client_read, server_write) = tokio::io::duplex(64 * 1024); let (server_read, client_write) = tokio::io::duplex(64 * 1024); - let agent = Arc::new(GooseAcpAgent::with_config(config).await.unwrap()); - let permission_manager = agent.permission_manager(); let handle = tokio::spawn(async move { if let Err(e) = serve(agent, server_read.compat(), server_write.compat_write()).await { tracing::error!("ACP server error: {e}"); } }); - (client_read, client_write, handle, permission_manager) + let transport = sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); + (transport, handle) +} + +#[allow(dead_code)] +pub async fn spawn_acp_server_in_process( + openai_base_url: &str, + builtins: &[String], + data_root: &Path, + goose_mode: GooseMode, +) -> (DuplexTransport, JoinHandle<()>, Arc) { + fs::create_dir_all(data_root).unwrap(); + // ensure_provider reads the model from config lazily, so tests need a config.yaml. + let config_path = data_root.join(goose::config::base::CONFIG_YAML_NAME); + if !config_path.exists() { + fs::write(&config_path, "GOOSE_MODEL: gpt-5-nano\n").unwrap(); + } + let base_url = openai_base_url.to_string(); + let provider_factory: ProviderConstructor = Arc::new(move |model_config| { + let base_url = base_url.clone(); + Box::pin(async move { + let api_client = + ApiClient::new(base_url, AuthMethod::BearerToken("test-key".to_string())).unwrap(); + let provider: Arc = + Arc::new(OpenAiProvider::new(api_client, model_config)); + Ok(provider) + }) + }); + + let agent = Arc::new( + GooseAcpAgent::new( + provider_factory, + builtins.to_vec(), + data_root.to_path_buf(), + data_root.to_path_buf(), + goose_mode, + true, + ) + .await + .unwrap(), + ); + let permission_manager = agent.permission_manager(); + let (transport, handle) = serve_agent_in_process(agent).await; + + (transport, handle, permission_manager) } pub struct TestOutput { @@ -463,4 +487,26 @@ where } } +/// Connects to the given agent via in-process duplex streams, sends an +/// `InitializeRequest`, and returns the response. +#[allow(dead_code)] +pub async fn initialize_agent(agent: Arc) -> sacp::schema::InitializeResponse { + let (transport, _handle) = serve_agent_in_process(agent).await; + sacp::ClientToAgent::builder() + .connect_to(transport) + .unwrap() + .run_until(|cx: sacp::JrConnectionCx| async move { + let resp = cx + .send_request(sacp::schema::InitializeRequest::new( + sacp::schema::ProtocolVersion::LATEST, + )) + .block_task() + .await + .unwrap(); + Ok::<_, sacp::Error>(resp) + }) + .await + .unwrap() +} + pub mod server; diff --git a/crates/goose-acp/tests/fixtures/server.rs b/crates/goose-acp/tests/fixtures/server.rs index bc81f2b2662c..572ebe337c54 100644 --- a/crates/goose-acp/tests/fixtures/server.rs +++ b/crates/goose-acp/tests/fixtures/server.rs @@ -13,7 +13,6 @@ use sacp::{ClientToAgent, JrConnectionCx}; use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::sync::Notify; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; pub struct ClientToAgentSession { cx: JrConnectionCx, @@ -39,7 +38,7 @@ impl Session for ClientToAgentSession { false => (config.data_root.clone(), None), }; - let (client_read, client_write, _handle, permission_manager) = spawn_acp_server_in_process( + let (transport, _handle, permission_manager) = spawn_acp_server_in_process( openai.uri(), &config.builtins, data_root.as_path(), @@ -51,8 +50,6 @@ impl Session for ClientToAgentSession { let notify = Arc::new(Notify::new()); let permission = Arc::new(Mutex::new(PermissionDecision::Cancel)); - let transport = sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); - let (cx, session_id) = { let updates_clone = updates.clone(); let notify_clone = notify.clone(); diff --git a/crates/goose-acp/tests/server_test.rs b/crates/goose-acp/tests/server_test.rs index c9241b2a90eb..5345a6148c66 100644 --- a/crates/goose-acp/tests/server_test.rs +++ b/crates/goose-acp/tests/server_test.rs @@ -1,10 +1,15 @@ mod common_tests; +use common_tests::fixtures::initialize_agent; use common_tests::fixtures::run_test; use common_tests::fixtures::server::ClientToAgentSession; use common_tests::{ run_basic_completion, run_builtin_and_mcp, run_configured_extension, run_mcp_http_server, run_permission_persistence, }; +use goose::config::GooseMode; +use goose::providers::provider_registry::ProviderConstructor; +use goose_acp::server::GooseAcpAgent; +use std::sync::Arc; #[test] fn test_acp_basic_completion() { @@ -30,3 +35,34 @@ fn test_permission_persistence() { fn test_configured_extension() { run_test(async { run_configured_extension::().await }); } + +#[test] +fn test_initialize_without_provider() { + run_test(async { + let temp_dir = tempfile::tempdir().unwrap(); + + let provider_factory: ProviderConstructor = + Arc::new(|_| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) })); + + let agent = Arc::new( + GooseAcpAgent::new( + provider_factory, + vec![], + temp_dir.path().to_path_buf(), + temp_dir.path().to_path_buf(), + GooseMode::Auto, + false, + ) + .await + .unwrap(), + ); + + // Initialization shouldn't fail even though we have a crashing provider factory. + let resp = initialize_agent(agent).await; + assert!(!resp.auth_methods.is_empty()); + assert!(resp + .auth_methods + .iter() + .any(|m| &*m.id.0 == "goose-provider")); + }); +} diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index b5a25b0f6578..542b0a627345 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -6,7 +6,7 @@ use futures::future::BoxFuture; use std::collections::HashMap; use std::sync::Arc; -type ProviderConstructor = +pub type ProviderConstructor = Arc BoxFuture<'static, Result>> + Send + Sync>; #[derive(Clone)]