diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 5832330bb764..36066c24d414 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -19,7 +19,7 @@ use crate::commands::session::{handle_session_list, handle_session_remove}; use crate::logging::setup_logging; use crate::recipes::recipe::{explain_recipe_with_parameters, load_recipe_as_template}; use crate::session; -use crate::session::{build_session, SessionBuilderConfig}; +use crate::session::{build_session, SessionBuilderConfig, SessionSettings}; use goose_bench::bench_config::BenchRunConfig; use goose_bench::runners::bench_runner::BenchRunner; use goose_bench::runners::eval_runner::EvalRunner; @@ -552,6 +552,7 @@ enum CliProviderVariant { Ollama, } +#[derive(Debug)] struct InputConfig { contents: Option, extensions_override: Option>, @@ -630,6 +631,7 @@ pub async fn cli() -> Result<()> { builtins, extensions_override: None, additional_system_prompt: None, + settings: None, debug, max_tool_repetitions, interactive: true, // Session command is always interactive @@ -676,18 +678,22 @@ pub async fn cli() -> Result<()> { params, explain, }) => { - let input_config = match (instructions, input_text, recipe, explain) { + let (input_config, session_settings) = match (instructions, input_text, recipe, explain) + { (Some(file), _, _, _) if file == "-" => { let mut input = String::new(); std::io::stdin() .read_to_string(&mut input) .expect("Failed to read from stdin"); - InputConfig { - contents: Some(input), - extensions_override: None, - additional_system_prompt: None, - } + ( + InputConfig { + contents: Some(input), + extensions_override: None, + additional_system_prompt: None, + }, + None, + ) } (Some(file), _, _, _) => { let contents = std::fs::read_to_string(&file).unwrap_or_else(|err| { @@ -697,17 +703,23 @@ pub async fn cli() -> Result<()> { ); std::process::exit(1); }); + ( + InputConfig { + contents: Some(contents), + extensions_override: None, + additional_system_prompt: None, + }, + None, + ) + } + (_, Some(text), _, _) => ( InputConfig { - contents: Some(contents), + contents: Some(text), extensions_override: None, additional_system_prompt: None, - } - } - (_, Some(text), _, _) => InputConfig { - contents: Some(text), - extensions_override: None, - additional_system_prompt: None, - }, + }, + None, + ), (_, _, Some(recipe_name), explain) => { if explain { explain_recipe_with_parameters(&recipe_name, params)?; @@ -718,11 +730,18 @@ pub async fn cli() -> Result<()> { eprintln!("{}: {}", console::style("Error").red().bold(), err); std::process::exit(1); }); - InputConfig { - contents: recipe.prompt, - extensions_override: recipe.extensions, - additional_system_prompt: recipe.instructions, - } + ( + InputConfig { + contents: recipe.prompt, + extensions_override: recipe.extensions, + additional_system_prompt: recipe.instructions, + }, + recipe.settings.map(|s| SessionSettings { + goose_provider: s.goose_provider, + goose_model: s.goose_model, + temperature: s.temperature, + }), + ) } (None, None, None, _) => { eprintln!("Error: Must provide either --instructions (-i), --text (-t), or --recipe. Use -i - for stdin."); @@ -739,6 +758,7 @@ pub async fn cli() -> Result<()> { builtins, extensions_override: input_config.extensions_override, additional_system_prompt: input_config.additional_system_prompt, + settings: session_settings, debug, max_tool_repetitions, interactive, // Use the interactive flag from the Run command @@ -854,6 +874,7 @@ pub async fn cli() -> Result<()> { builtins: Vec::new(), extensions_override: None, additional_system_prompt: None, + settings: None::, debug: false, max_tool_repetitions: None, interactive: true, // Default case is always interactive diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs index 1d2488334b87..01cbd6a037ae 100644 --- a/crates/goose-cli/src/commands/bench.rs +++ b/crates/goose-cli/src/commands/bench.rs @@ -40,6 +40,7 @@ pub async fn agent_generator( builtins: requirements.builtin, extensions_override: None, additional_system_prompt: None, + settings: None, debug: false, max_tool_repetitions: None, interactive: false, // Benchmarking is non-interactive diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 3c314b745432..adc8c4915f45 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -34,6 +34,8 @@ pub struct SessionBuilderConfig { pub extensions_override: Option>, /// Any additional system prompt to append to the default pub additional_system_prompt: Option, + /// Settings to override the global Goose settings + pub settings: Option, /// Enable debug printing pub debug: bool, /// Maximum number of consecutive identical tool calls allowed @@ -136,18 +138,35 @@ async fn offer_extension_debugging_help( Ok(()) } +#[derive(Clone, Debug, Default)] +pub struct SessionSettings { + pub goose_model: Option, + pub goose_provider: Option, + pub temperature: Option, +} + pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Load config and get provider/model let config = Config::global(); - let provider_name: String = config - .get_param("GOOSE_PROVIDER") + let provider_name = session_config + .settings + .as_ref() + .and_then(|s| s.goose_provider.clone()) + .or_else(|| config.get_param("GOOSE_PROVIDER").ok()) .expect("No provider configured. Run 'goose configure' first"); - let model: String = config - .get_param("GOOSE_MODEL") + let model_name = session_config + .settings + .as_ref() + .and_then(|s| s.goose_model.clone()) + .or_else(|| config.get_param("GOOSE_MODEL").ok()) .expect("No model configured. Run 'goose configure' first"); - let model_config = goose::model::ModelConfig::new(model.clone()); + + let temperature = session_config.settings.as_ref().and_then(|s| s.temperature); + + let model_config = + goose::model::ModelConfig::new(model_name.clone()).with_temperature(temperature); // Create the agent let agent: Agent = Agent::new(); @@ -165,7 +184,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { worker_model ); } else { - tracing::info!("🤖 Using model: {}", model); + tracing::info!("🤖 Using model: {}", model_name); } agent @@ -430,7 +449,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { output::display_session_info( session_config.resume, &provider_name, - &model, + &model_name, &session_file, Some(&provider_for_display), ); @@ -452,6 +471,7 @@ mod tests { builtins: vec!["developer".to_string()], extensions_override: None, additional_system_prompt: Some("Test prompt".to_string()), + settings: None, debug: true, max_tool_repetitions: Some(5), interactive: true, diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 6bdba33117c0..8eeec0ea22e5 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -7,7 +7,7 @@ mod prompt; mod thinking; pub use self::export::message_to_markdown; -pub use builder::{build_session, SessionBuilderConfig}; +pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; use goose::agents::AgentEvent; use goose::permission::permission_confirmation::PrincipalType; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 991603007c87..da15ca8d3d04 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -16,7 +16,7 @@ use crate::permission::permission_judge::check_tool_permissions; use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; -use crate::recipe::{Author, Recipe}; +use crate::recipe::{Author, Recipe, Settings}; use crate::tool_monitor::{ToolCall, ToolMonitor}; use regex::Regex; use serde_json::Value; @@ -973,12 +973,26 @@ impl Agent { metadata: None, }; + // Ideally we'd get the name of the provider we are using from the provider itself + // but it doesn't know and the plumbing looks complicated. + let config = Config::global(); + let provider_name: String = config + .get_param("GOOSE_PROVIDER") + .expect("No provider configured. Run 'goose configure' first"); + + let settings = Settings { + goose_provider: Some(provider_name.clone()), + goose_model: Some(model_name.clone()), + temperature: Some(model_config.temperature.unwrap_or(0.0)), + }; + let recipe = Recipe::builder() .title("Custom recipe from chat") .description("a custom recipe instance from this chat session") .instructions(instructions) .activities(activities) .extensions(extension_configs) + .settings(settings) .author(author) .build() .expect("valid recipe"); diff --git a/crates/goose/src/agents/prompt_manager.rs b/crates/goose/src/agents/prompt_manager.rs index cdd367c9ed9b..a5a4e41ff962 100644 --- a/crates/goose/src/agents/prompt_manager.rs +++ b/crates/goose/src/agents/prompt_manager.rs @@ -154,7 +154,6 @@ impl PromptManager { } } - /// Get the recipe prompt pub async fn get_recipe_prompt(&self) -> String { let context: HashMap<&str, Value> = HashMap::new(); prompt_template::render_global_file("recipe.md", &context).expect("Prompt should render") diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 510ba000a02c..06a1bc8bb67e 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -50,6 +50,7 @@ fn default_version() -> String { /// context: None, /// activities: None, /// author: None, +/// settings: None, /// parameters: None, /// }; /// @@ -77,6 +78,9 @@ pub struct Recipe { #[serde(skip_serializing_if = "Option::is_none")] pub context: Option>, // any additional context + #[serde(skip_serializing_if = "Option::is_none")] + pub settings: Option, // settings for the recipe + #[serde(skip_serializing_if = "Option::is_none")] pub activities: Option>, // the activity pills that show up when loading the @@ -96,6 +100,18 @@ pub struct Author { pub metadata: Option, // any additional metadata for the author } +#[derive(Serialize, Deserialize, Debug)] +pub struct Settings { + #[serde(skip_serializing_if = "Option::is_none")] + pub goose_provider: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub goose_model: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, +} + #[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "snake_case")] pub enum RecipeParameterRequirement { @@ -156,6 +172,7 @@ pub struct RecipeBuilder { prompt: Option, extensions: Option>, context: Option>, + settings: Option, activities: Option>, author: Option, parameters: Option>, @@ -185,6 +202,7 @@ impl Recipe { prompt: None, extensions: None, context: None, + settings: None, activities: None, author: None, parameters: None, @@ -234,6 +252,11 @@ impl RecipeBuilder { self } + pub fn settings(mut self, settings: Settings) -> Self { + self.settings = Some(settings); + self + } + /// Sets the activities for the Recipe pub fn activities(mut self, activities: Vec) -> Self { self.activities = Some(activities); @@ -271,6 +294,7 @@ impl RecipeBuilder { prompt: self.prompt, extensions: self.extensions, context: self.context, + settings: self.settings, activities: self.activities, author: self.author, parameters: self.parameters, diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 275139bf1d3c..1b4871183c2d 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1300,6 +1300,7 @@ mod tests { activities: None, author: None, parameters: None, + settings: None, }; let mut recipe_file = File::create(&recipe_filename)?; writeln!(