Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions crates/goose-acp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ impl GooseAcpAgent {
goose::model::ModelConfig::new(&model_id)?
}
};
let provider = (self.provider_factory)(model_config).await?;
let provider = (self.provider_factory)(model_config, Vec::new()).await?;
agent.update_provider(provider.clone(), &session.id).await?;
Ok(provider)
}
Expand Down Expand Up @@ -958,9 +958,11 @@ impl GooseAcpAgent {
let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| {
sacp::Error::invalid_params().data(format!("Invalid model config: {}", e))
})?;
let provider = (self.provider_factory)(model_config).await.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
})?;
let provider = (self.provider_factory)(model_config, Vec::new())
.await
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
})?;

let agent = {
let sessions = self.sessions.lock().await;
Expand Down
4 changes: 2 additions & 2 deletions crates/goose-acp/src/server_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ impl AcpServer {
let disable_session_naming = config.get_goose_disable_session_naming().unwrap_or(false);

let config_dir = self.config.config_dir.clone();
let provider_factory: ProviderConstructor = Arc::new(move |model_config| {
let provider_factory: ProviderConstructor = Arc::new(move |model_config, extensions| {
let config_dir = config_dir.clone();
Box::pin(async move {
let config_path = config_dir.join(goose::config::base::CONFIG_YAML_NAME);
let config = goose::config::Config::new(&config_path, "goose")?;
let provider_name = config
.get_goose_provider()
.map_err(|_| anyhow::anyhow!("No provider configured"))?;
goose::providers::create(&provider_name, model_config).await
goose::providers::create(&provider_name, model_config, extensions).await
})
});

Expand Down
2 changes: 1 addition & 1 deletion crates/goose-acp/tests/common_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub async fn run_initialize_without_provider() {
let temp_dir = tempfile::tempdir().unwrap();

let provider_factory: ProviderConstructor =
Arc::new(|_| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) }));
Arc::new(|_, _| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) }));

let agent = Arc::new(
GooseAcpAgent::new(
Expand Down
2 changes: 1 addition & 1 deletion crates/goose-acp/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ pub async fn spawn_acp_server_in_process(
}
let provider_factory = provider_factory.unwrap_or_else(|| {
let base_url = openai_base_url.to_string();
Arc::new(move |model_config| {
Arc::new(move |model_config, _extensions| {
let base_url = base_url.clone();
Box::pin(async move {
let api_client =
Expand Down
2 changes: 2 additions & 0 deletions crates/goose-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ indicatif = "0.18.1"
tokio-util = { workspace = true, features = ["compat", "rt"] }
anstream = "0.6.18"
open = "5.3.2"
url = { workspace = true }
urlencoding = { workspace = true }
clap_complete = "4.5.62"

Expand All @@ -70,4 +71,5 @@ disable-update = []

[dev-dependencies]
tempfile = { workspace = true }
test-case = { workspace = true }
tokio = { workspace = true }
17 changes: 10 additions & 7 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async fn handle_oauth_configuration(provider_name: &str, key_name: &str) -> anyh

// Create a temporary provider instance to handle OAuth
let temp_model = ModelConfig::new("temp")?;
match create(provider_name, temp_model).await {
match create(provider_name, temp_model, Vec::new()).await {
Ok(provider) => match provider.configure_oauth().await {
Ok(_) => {
let _ = cliclack::log::success("OAuth authentication completed successfully!");
Expand Down Expand Up @@ -683,7 +683,7 @@ pub async fn configure_provider_dialog() -> anyhow::Result<bool> {
spin.start("Attempting to fetch supported models...");
let models_res = {
let temp_model_config = ModelConfig::new(&provider_meta.default_model)?;
let temp_provider = create(provider_name, temp_model_config).await?;
let temp_provider = create(provider_name, temp_model_config, Vec::new()).await?;
retry_operation(&RetryConfig::default(), || async {
temp_provider.fetch_recommended_models().await
})
Expand Down Expand Up @@ -1445,7 +1445,6 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> {
let model_config = ModelConfig::new(&model)?;

let agent = Agent::new();
let new_provider = create(&provider_name, model_config).await?;

let session = agent
.config
Expand All @@ -1457,8 +1456,8 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> {
)
.await?;

agent.update_provider(new_provider, &session.id).await?;
if let Some(config) = get_extension_by_name(&selected_extension_name) {
let extension_config = get_extension_by_name(&selected_extension_name);
if let Some(config) = extension_config.as_ref() {
agent
.add_extension(config.clone(), &session.id)
.await
Expand All @@ -1478,6 +1477,10 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> {
return Ok(());
}

let extensions = extension_config.into_iter().collect::<Vec<_>>();
let new_provider = create(&provider_name, model_config, extensions).await?;
agent.update_provider(new_provider, &session.id).await?;

let permission_manager = PermissionManager::instance();
let selected_tools = agent
.list_tools(&session.id, Some(selected_extension_name.clone()))
Expand Down Expand Up @@ -1667,7 +1670,7 @@ pub async fn handle_openrouter_auth() -> anyhow::Result<()> {
}
};

match create("openrouter", model_config).await {
match create("openrouter", model_config, Vec::new()).await {
Ok(provider) => {
let model_config = provider.get_model_config();
let test_result = provider
Expand Down Expand Up @@ -1747,7 +1750,7 @@ pub async fn handle_tetrate_auth() -> anyhow::Result<()> {
}
};

match create("tetrate", model_config).await {
match create("tetrate", model_config, Vec::new()).await {
Ok(provider) => {
let test_result = provider.fetch_supported_models().await;

Expand Down
8 changes: 4 additions & 4 deletions crates/goose-cli/src/commands/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ async fn create_agent(provider_name: &str, model: &str) -> Result<Agent> {
)
.await?;

let provider = goose::providers::create(provider_name, model_config).await?;
agent.update_provider(provider, &init_session.id).await?;

let enabled_configs = goose::config::get_enabled_extensions();
for config in enabled_configs {
for config in &enabled_configs {
if let Err(e) = agent.add_extension(config.clone(), &init_session.id).await {
eprintln!("Warning: Failed to load extension {}: {}", config.name(), e);
}
}

let provider = goose::providers::create(provider_name, model_config, enabled_configs).await?;
agent.update_provider(provider, &init_session.id).await?;

Ok(agent)
}

Expand Down
7 changes: 6 additions & 1 deletion crates/goose-cli/src/scenario_tests/scenario_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,12 @@ where

let original_env = setup_environment(config)?;

let inner_provider = create(&factory_name, ModelConfig::new(config.model_name)?).await?;
let inner_provider = create(
&factory_name,
ModelConfig::new(config.model_name)?,
Vec::new(),
)
.await?;

let test_provider = Arc::new(TestProvider::new_recording(inner_provider, &file_path));
(
Expand Down
92 changes: 57 additions & 35 deletions crates/goose-cli/src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ use crate::cli::StreamableHttpOptions;
use super::output;
use super::CliSession;
use console::style;
use goose::agents::{Agent, Container};
use goose::config::get_enabled_extensions;
use goose::agents::{Agent, Container, ExtensionError};
use goose::config::resolve_extensions_for_new_session;
use goose::config::{get_all_extensions, Config, ExtensionConfig};
use goose::providers::create;
use goose::recipe::Recipe;
use goose::session::session_manager::SessionType;
use goose::session::{EnabledExtensionsState, ExtensionState};
use goose::session::EnabledExtensionsState;
use rustyline::EditMode;
use std::collections::BTreeSet;
use std::process;
Expand Down Expand Up @@ -490,27 +489,19 @@ async fn handle_resumed_session_workdir(agent: &Agent, session_id: &str, interac
}
}

async fn resolve_and_load_extensions(
agent: Agent,
async fn collect_extension_configs(
agent: &Agent,
session_config: &SessionBuilderConfig,
recipe: Option<&Recipe>,
session_id: &str,
provider_for_debug: Arc<dyn goose::providers::base::Provider>,
) -> Arc<Agent> {
for warning in goose::config::get_warnings() {
eprintln!("{}", style(format!("Warning: {}", warning)).yellow());
}

) -> Result<Vec<ExtensionConfig>, ExtensionError> {
let configured_extensions: Vec<ExtensionConfig> = if session_config.resume {
agent
.config
.session_manager
.get_session(session_id, false)
.await
.ok()
.and_then(|s| EnabledExtensionsState::from_extension_data(&s.extension_data))
.map(|state| state.extensions)
.unwrap_or_else(get_enabled_extensions)
EnabledExtensionsState::for_session(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this to avoid so many copy/paste

&agent.config.session_manager,
session_id,
Config::global(),
)
.await
} else if session_config.no_profile {
Vec::new()
} else {
Expand All @@ -523,17 +514,33 @@ async fn resolve_and_load_extensions(
&session_config.builtins,
);

let mut extensions_to_load: Vec<(String, ExtensionConfig)> = configured_extensions
.iter()
.map(|cfg| (cfg.name(), cfg.clone()))
let mut all: Vec<ExtensionConfig> = configured_extensions;
all.extend(cli_flag_extensions.into_iter().map(|(_, cfg)| cfg));

Ok(all)
}

async fn resolve_and_load_extensions(
agent: Agent,
extensions: Vec<ExtensionConfig>,
provider_for_debug: Arc<dyn goose::providers::base::Provider>,
interactive: bool,
session_id: &str,
) -> Arc<Agent> {
for warning in goose::config::get_warnings() {
eprintln!("{}", style(format!("Warning: {}", warning)).yellow());
}

let extensions_to_load: Vec<(String, ExtensionConfig)> = extensions
.into_iter()
.map(|cfg| (cfg.name(), cfg))
.collect();
extensions_to_load.extend(cli_flag_extensions);

load_extensions(
agent,
extensions_to_load,
provider_for_debug,
session_config.interactive,
interactive,
session_id,
)
.await
Expand Down Expand Up @@ -598,7 +605,28 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
.apply_recipe_components(recipe.and_then(|r| r.response.clone()), true)
.await;

let new_provider = match create(&resolved.provider_name, resolved.model_config).await {
let session_id = resolve_session_id(&session_config, &session_manager).await;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reordering here is necessary to resolve the extensions before we init the provider.


if session_config.resume {
handle_resumed_session_workdir(&agent, &session_id, session_config.interactive).await;
}

let extensions_for_provider =
match collect_extension_configs(&agent, &session_config, recipe, &session_id).await {
Ok(exts) => exts,
Err(e) => {
output::render_error(&format!("Failed to collect extensions: {}", e));
process::exit(1);
}
};

let new_provider = match create(
&resolved.provider_name,
resolved.model_config,
extensions_for_provider.clone(),
)
.await
{
Ok(provider) => provider,
Err(e) => {
output::render_error(&format!(
Expand All @@ -624,8 +652,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
tracing::info!("πŸ€– Using model: {}", resolved.model_name);
}

let session_id = resolve_session_id(&session_config, &session_manager).await;

agent
.update_provider(new_provider, &session_id)
.await
Expand All @@ -645,17 +671,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
}
}

if session_config.resume {
handle_resumed_session_workdir(&agent, &session_id, session_config.interactive).await;
}

// Extensions are loaded after session creation because we may change directory when resuming
let agent_ptr = resolve_and_load_extensions(
agent,
&session_config,
recipe,
&session_id,
extensions_for_provider,
Arc::clone(&provider_for_display),
session_config.interactive,
&session_id,
)
.await;

Expand Down
Loading
Loading