From d409f69ad2b214daf5ce129faff98230c8314092 Mon Sep 17 00:00:00 2001 From: Douwe Osinga Date: Tue, 29 Jul 2025 22:32:39 +0200 Subject: [PATCH 1/3] Ok, well, that got out of hand --- .../goose/src/agents/router_tool_selector.rs | 3 +- crates/goose/src/context_mgmt/summarize.rs | 15 +- crates/goose/src/lib.rs | 2 + crates/goose/src/macros.rs | 19 + crates/goose/src/model.rs | 435 +++++++++--------- .../goose/src/permission/permission_judge.rs | 4 +- crates/goose/src/providers/anthropic.rs | 8 +- crates/goose/src/providers/azure.rs | 8 +- crates/goose/src/providers/base.rs | 3 +- crates/goose/src/providers/bedrock.rs | 19 +- crates/goose/src/providers/claude_code.rs | 13 +- crates/goose/src/providers/databricks.rs | 8 +- crates/goose/src/providers/factory.rs | 19 +- .../goose/src/providers/formats/anthropic.rs | 6 +- .../goose/src/providers/formats/snowflake.rs | 4 +- crates/goose/src/providers/gcpvertexai.rs | 10 +- crates/goose/src/providers/gemini_cli.rs | 12 +- crates/goose/src/providers/githubcopilot.rs | 8 +- crates/goose/src/providers/google.rs | 8 +- crates/goose/src/providers/groq.rs | 8 +- crates/goose/src/providers/lead_worker.rs | 12 +- crates/goose/src/providers/litellm.rs | 8 +- crates/goose/src/providers/ollama.rs | 8 +- crates/goose/src/providers/openai.rs | 8 +- crates/goose/src/providers/openrouter.rs | 8 +- crates/goose/src/providers/sagemaker_tgi.rs | 8 +- crates/goose/src/providers/snowflake.rs | 8 +- crates/goose/src/providers/testprovider.rs | 4 +- crates/goose/src/providers/toolshim.rs | 3 +- crates/goose/src/providers/venice.rs | 8 +- crates/goose/src/providers/xai.rs | 8 +- crates/goose/src/recipe/mod.rs | 1 - crates/goose/src/scheduler.rs | 10 +- 33 files changed, 339 insertions(+), 367 deletions(-) create mode 100644 crates/goose/src/macros.rs diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index a75c648d844f..706144af245e 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -51,7 +51,8 @@ impl VectorToolSelector { env::var("GOOSE_EMBEDDING_MODEL_PROVIDER").unwrap_or_else(|_| "openai".to_string()); // Create the provider using the factory - let model_config = ModelConfig::new(embedding_model); + let model_config = ModelConfig::new(embedding_model.as_str()) + .context("Failed to create model config for embedding provider")?; providers::create(&embedding_provider_name, model_config).context(format!( "Failed to create {} provider for embeddings. If using OpenAI, make sure OPENAI_API_KEY env var is set or that you have configured the OpenAI provider via Goose before.", embedding_provider_name diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs index cd09f4fcf726..ef798a2d2df3 100644 --- a/crates/goose/src/context_mgmt/summarize.rs +++ b/crates/goose/src/context_mgmt/summarize.rs @@ -264,12 +264,13 @@ mod tests { } } - fn create_mock_provider() -> Arc { + fn create_mock_provider() -> Result> { let mock_model_config = - ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into()); - Arc::new(MockProvider { + ModelConfig::new_or_fail("test-model").with_context_limit(200_000.into()); + + Ok(Arc::new(MockProvider { model_config: mock_model_config, - }) + })) } fn create_test_messages() -> Vec { @@ -305,7 +306,7 @@ mod tests { #[tokio::test] async fn test_summarize_messages_single_chunk() { - let provider = create_mock_provider(); + let provider = create_mock_provider().expect("failed to create mock provider"); let token_counter = TokenCounter::new(); let context_limit = 100; // Set a high enough limit to avoid chunking. let messages = create_test_messages(); @@ -341,7 +342,7 @@ mod tests { #[tokio::test] async fn test_summarize_messages_multiple_chunks() { - let provider = create_mock_provider(); + let provider = create_mock_provider().expect("failed to create mock provider"); let token_counter = TokenCounter::new(); let context_limit = 30; let messages = create_test_messages(); @@ -377,7 +378,7 @@ mod tests { #[tokio::test] async fn test_summarize_messages_empty_input() { - let provider = create_mock_provider(); + let provider = create_mock_provider().expect("failed to create mock provider"); let token_counter = TokenCounter::new(); let context_limit = 100; let messages: Vec = Vec::new(); diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index f6820d0fee4f..63f8f0aab3b8 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -22,3 +22,5 @@ pub mod utils; #[cfg(test)] mod cron_test; +#[macro_use] +mod macros; diff --git a/crates/goose/src/macros.rs b/crates/goose/src/macros.rs new file mode 100644 index 000000000000..01d853dec0ad --- /dev/null +++ b/crates/goose/src/macros.rs @@ -0,0 +1,19 @@ +#[macro_export] +macro_rules! impl_provider_default { + ($provider:ty) => { + impl Default for $provider { + fn default() -> Self { + let model = $crate::model::ModelConfig::new( + &<$provider as $crate::providers::base::Provider>::metadata().default_model, + ) + .expect(concat!( + "Failed to create model config for ", + stringify!($provider) + )); + + <$provider>::from_env(model) + .expect(concat!("Failed to initialize ", stringify!($provider))) + } + } + }; +} diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index 7c1096103c53..b41fd5195a38 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -1,64 +1,78 @@ use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use thiserror::Error; const DEFAULT_CONTEXT_LIMIT: usize = 128_000; -// Define the model limits as a static HashMap for reuse -static MODEL_SPECIFIC_LIMITS: Lazy> = Lazy::new(|| { - let mut map = HashMap::new(); - // OpenAI models, https://platform.openai.com/docs/models#models-overview - map.insert("gpt-4o", 128_000); - map.insert("gpt-4-turbo", 128_000); - map.insert("o3", 200_000); - map.insert("o3-mini", 200_000); - map.insert("o4-mini", 200_000); - map.insert("gpt-4.1", 1_000_000); - map.insert("gpt-4-1", 1_000_000); - - // Anthropic models, https://docs.anthropic.com/en/docs/about-claude/models - map.insert("claude", 200_000); - - // Google models, https://ai.google/get-started/our-models/ - map.insert("gemini-2.5", 1_000_000); - map.insert("gemini-2-5", 1_000_000); - - // Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1 - map.insert("llama3.2", 128_000); - map.insert("llama3.3", 128_000); - - // x.ai Grok models, https://docs.x.ai/docs/overview - map.insert("grok", 131_072); - - // Groq models, https://console.groq.com/docs/models - map.insert("gemma2-9b", 8_192); - map.insert("kimi-k2", 131_072); - map.insert("qwen3-32b", 131_072); - map.insert("grok-3", 131_072); - map.insert("grok-4", 256_000); // 256K - map.insert("qwen3-coder", 262_144); // 262K - - map +#[derive(Error, Debug)] +pub enum ConfigError { + #[error("Environment variable '{0}' not found")] + EnvVarMissing(String), + #[error("Invalid value for '{0}': '{1}' - {2}")] + InvalidValue(String, String, String), + #[error("Value for '{0}' is out of valid range: {1}")] + InvalidRange(String, String), +} + +static MODEL_SPECIFIC_LIMITS: Lazy> = Lazy::new(|| { + vec![ + // openai + ("gpt-4-turbo", 128_000), + ("gpt-4.1", 1_000_000), + ("gpt-4-1", 1_000_000), + ("gpt-4o", 128_000), + ("o4-mini", 200_000), + ("o3-mini", 200_000), + ("o3", 200_000), + // anthropic - all 200k + ("claude", 200_000), + // google + ("gemini-1", 128_000), + ("gemini-2", 1_000_000), + ("gemma-3-27b", 128_000), + ("gemma-3-12b", 128_000), + ("gemma-3-4b", 128_000), + ("gemma-3-1b", 32_000), + ("gemma3-27b", 128_000), + ("gemma3-12b", 128_000), + ("gemma3-4b", 128_000), + ("gemma3-1b", 32_000), + ("gemma-2-27b", 8_192), + ("gemma-2-9b", 8_192), + ("gemma-2-2b", 8_192), + ("gemma2-", 8_192), + ("gemma-7b", 8_192), + ("gemma-2b", 8_192), + ("gemma1", 8_192), + ("gemma", 8_192), + // facebook + ("llama-2-1b", 32_000), + ("llama", 128_000), + // qwen + ("qwen3-coder", 262_144), + ("qwen2-7b", 128_000), + ("qwen2-14b", 128_000), + ("qwen2-32b", 131_072), + ("qwen2-70b", 262_144), + ("qwen2", 128_000), + ("qwen3-32b", 131_072), + // other + ("kimi-k2", 131_072), + ("grok-4", 256_000), + ("grok", 131_072), + ] }); -/// Configuration for model-specific settings and limits #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelConfig { - /// The name of the model to use pub model_name: String, - /// Optional explicit context limit that overrides any defaults pub context_limit: Option, - /// Optional temperature setting (0.0 - 1.0) pub temperature: Option, - /// Optional maximum tokens to generate pub max_tokens: Option, - /// Whether to interpret tool calls with toolshim pub toolshim: bool, - /// Model to use for toolshim (optional as a default exists) pub toolshim_model: Option, } -/// Struct to represent model pattern matches and their limits #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelLimitConfig { pub pattern: String, @@ -66,258 +80,267 @@ pub struct ModelLimitConfig { } impl ModelConfig { - /// Create a new ModelConfig with the specified model name - /// - /// The context limit is set with the following precedence: - /// 1. Explicit context_limit if provided in config - /// 2. Environment variable override (GOOSE_CONTEXT_LIMIT) - /// 3. Model-specific default based on model name - /// 4. Global default (128_000) (in get_context_limit) - pub fn new(model_name: String) -> Self { - Self::new_with_context_env(model_name, None) + pub fn new(model_name: &str) -> Result { + Self::new_with_context_env(model_name.to_string(), None) } - /// Create a new ModelConfig with the specified model name and custom context limit env var - /// - /// This is useful for specific model purposes like lead, worker, planner models - /// that may have their own context limit environment variables. - pub fn new_with_context_env(model_name: String, context_env_var: Option<&str>) -> Self { - let context_limit = Self::get_context_limit_with_env_override(&model_name, context_env_var); - - let toolshim = std::env::var("GOOSE_TOOLSHIM") - .map(|val| val == "1" || val.to_lowercase() == "true") - .unwrap_or(false); - - let toolshim_model = std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL").ok(); + pub fn new_with_context_env( + model_name: String, + context_env_var: Option<&str>, + ) -> Result { + let context_limit = Self::parse_context_limit(&model_name, context_env_var)?; + let temperature = Self::parse_temperature()?; + let toolshim = Self::parse_toolshim()?; + let toolshim_model = Self::parse_toolshim_model()?; - let temperature = std::env::var("GOOSE_TEMPERATURE") - .ok() - .and_then(|val| val.parse::().ok()); - - Self { + Ok(Self { model_name, context_limit, temperature, max_tokens: None, toolshim, toolshim_model, + }) + } + + fn parse_context_limit( + model_name: &str, + custom_env_var: Option<&str>, + ) -> Result, ConfigError> { + if let Some(env_var) = custom_env_var { + if let Ok(val) = std::env::var(env_var) { + return Self::validate_context_limit(&val, env_var).map(Some); + } } + if let Ok(val) = std::env::var("GOOSE_CONTEXT_LIMIT") { + return Self::validate_context_limit(&val, "GOOSE_CONTEXT_LIMIT").map(Some); + } + Ok(Self::get_model_specific_limit(model_name)) } - /// Get model-specific context limit based on model name - fn get_model_specific_limit(model_name: &str) -> Option { - for (pattern, &limit) in MODEL_SPECIFIC_LIMITS.iter() { - if model_name.contains(pattern) { - return Some(limit); + fn validate_context_limit(val: &str, env_var: &str) -> Result { + let limit = val.parse::().map_err(|_| { + ConfigError::InvalidValue( + env_var.to_string(), + val.to_string(), + "must be a positive integer".to_string(), + ) + })?; + + if limit < 4 * 1024 { + return Err(ConfigError::InvalidRange( + env_var.to_string(), + "must be greater than 4K".to_string(), + )); + } + + Ok(limit) + } + + fn parse_temperature() -> Result, ConfigError> { + if let Ok(val) = std::env::var("GOOSE_TEMPERATURE") { + let temp = val.parse::().map_err(|_| { + ConfigError::InvalidValue( + "GOOSE_TEMPERATURE".to_string(), + val.clone(), + "must be a valid number".to_string(), + ) + })?; + if temp < 0.0 { + return Err(ConfigError::InvalidRange( + "GOOSE_TEMPERATURE".to_string(), + val, + )); + } + Ok(Some(temp)) + } else { + Ok(None) + } + } + + fn parse_toolshim() -> Result { + if let Ok(val) = std::env::var("GOOSE_TOOLSHIM") { + match val.to_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Ok(true), + "0" | "false" | "no" | "off" => Ok(false), + _ => Err(ConfigError::InvalidValue( + "GOOSE_TOOLSHIM".to_string(), + val, + "must be one of: 1, true, yes, on, 0, false, no, off".to_string(), + )), } + } else { + Ok(false) } - None } - /// Get all model pattern matches and their limits + fn parse_toolshim_model() -> Result, ConfigError> { + match std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL") { + Ok(val) if val.trim().is_empty() => Err(ConfigError::InvalidValue( + "GOOSE_TOOLSHIM_OLLAMA_MODEL".to_string(), + val, + "cannot be empty if set".to_string(), + )), + Ok(val) => Ok(Some(val)), + Err(_) => Ok(None), + } + } + + fn get_model_specific_limit(model_name: &str) -> Option { + MODEL_SPECIFIC_LIMITS + .iter() + .find(|(pattern, _)| model_name.contains(pattern)) + .map(|(_, limit)| *limit) + } + pub fn get_all_model_limits() -> Vec { MODEL_SPECIFIC_LIMITS .iter() - .map(|(&pattern, &context_limit)| ModelLimitConfig { + .map(|(pattern, context_limit)| ModelLimitConfig { pattern: pattern.to_string(), - context_limit, + context_limit: *context_limit, }) .collect() } - /// Set an explicit context limit pub fn with_context_limit(mut self, limit: Option) -> Self { - // Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set - // if input is Some to allow passing through with_context_limit in - // configuration cases if limit.is_some() { self.context_limit = limit; } self } - /// Set the temperature pub fn with_temperature(mut self, temp: Option) -> Self { self.temperature = temp; self } - /// Set the max tokens pub fn with_max_tokens(mut self, tokens: Option) -> Self { self.max_tokens = tokens; self } - /// Set whether to interpret tool calls pub fn with_toolshim(mut self, toolshim: bool) -> Self { self.toolshim = toolshim; self } - /// Set the tool call interpreter model pub fn with_toolshim_model(mut self, model: Option) -> Self { self.toolshim_model = model; self } - /// Get the context_limit for the current model - /// If none are defined, use the DEFAULT_CONTEXT_LIMIT pub fn context_limit(&self) -> usize { self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT) } - /// Get context limit with environment variable override support - /// - /// The context limit is resolved with the following precedence: - /// 1. Custom environment variable (if specified) - /// 2. GOOSE_CONTEXT_LIMIT (default environment variable) - /// 3. Model-specific default based on model name - /// 4. Global default (128_000) - fn get_context_limit_with_env_override( - model_name: &str, - custom_env_var: Option<&str>, - ) -> Option { - // 1. Check custom environment variable first (e.g., GOOSE_LEAD_CONTEXT_LIMIT) - if let Some(env_var) = custom_env_var { - if let Ok(limit_str) = std::env::var(env_var) { - if let Ok(limit) = limit_str.parse::() { - return Some(limit); - } - } - } - - // 2. Check default context limit environment variable - if let Ok(limit_str) = std::env::var("GOOSE_CONTEXT_LIMIT") { - if let Ok(limit) = limit_str.parse::() { - return Some(limit); - } - } - - // 3. Fall back to model-specific defaults - Self::get_model_specific_limit(model_name) + pub fn new_or_fail(model_name: &str) -> ModelConfig { + ModelConfig::new(model_name) + .expect(&format!("Failed to create model config for {}", model_name)) } } #[cfg(test)] mod tests { use super::*; + use temp_env::with_var; #[test] fn test_model_config_context_limits() { - // Test explicit limit - let config = - ModelConfig::new("claude-3-opus".to_string()).with_context_limit(Some(150_000)); + let config = ModelConfig::new("claude-3-opus") + .unwrap() + .with_context_limit(Some(150_000)); assert_eq!(config.context_limit(), 150_000); - // Test model-specific defaults - let config = ModelConfig::new("claude-3-opus".to_string()); + let config = ModelConfig::new("claude-3-opus").unwrap(); assert_eq!(config.context_limit(), 200_000); - let config = ModelConfig::new("gpt-4-turbo".to_string()); + let config = ModelConfig::new("gpt-4-turbo").unwrap(); assert_eq!(config.context_limit(), 128_000); - // Test fallback to default - let config = ModelConfig::new("unknown-model".to_string()); + let config = ModelConfig::new("unknown-model").unwrap(); assert_eq!(config.context_limit(), DEFAULT_CONTEXT_LIMIT); } #[test] - fn test_model_config_settings() { - let config = ModelConfig::new("test-model".to_string()) - .with_temperature(Some(0.7)) - .with_max_tokens(Some(1000)) - .with_context_limit(Some(50_000)); - - assert_eq!(config.temperature, Some(0.7)); - assert_eq!(config.max_tokens, Some(1000)); - assert_eq!(config.context_limit, Some(50_000)); - } + fn test_invalid_context_limit() { + with_var("GOOSE_CONTEXT_LIMIT", Some("abc"), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); + if let Err(ConfigError::InvalidValue(var, val, msg)) = result { + assert_eq!(var, "GOOSE_CONTEXT_LIMIT"); + assert_eq!(val, "abc"); + assert!(msg.contains("positive integer")); + } + }); - #[test] - fn test_model_config_tool_interpretation() { - // Test without env vars - should be false - let config = ModelConfig::new("test-model".to_string()); - assert!(!config.toolshim); - - // Test with tool interpretation setting - let config = ModelConfig::new("test-model".to_string()).with_toolshim(true); - assert!(config.toolshim); - - // Test tool interpreter model - let config = ModelConfig::new("test-model".to_string()) - .with_toolshim_model(Some("mistral-nemo".to_string())); - assert_eq!(config.toolshim_model, Some("mistral-nemo".to_string())); + with_var("GOOSE_CONTEXT_LIMIT", Some("0"), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidRange(_, _) + )); + }); } #[test] - fn test_model_config_temp_env_var() { - use temp_env::with_var; - - with_var("GOOSE_TEMPERATURE", Some("0.128"), || { - let config = ModelConfig::new("test-model".to_string()); - assert_eq!(config.temperature, Some(0.128)); + fn test_invalid_temperature() { + with_var("GOOSE_TEMPERATURE", Some("hot"), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); }); - with_var("GOOSE_TEMPERATURE", Some("notanum"), || { - let config = ModelConfig::new("test-model".to_string()); - assert_eq!(config.temperature, None); + with_var("GOOSE_TEMPERATURE", Some("-1.0"), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); }); - - with_var("GOOSE_TEMPERATURE", Some(""), || { - let config = ModelConfig::new("test-model".to_string()); - assert_eq!(config.temperature, None); - }); - - let config = ModelConfig::new("test-model".to_string()); - assert_eq!(config.temperature, None); } #[test] - fn test_get_all_model_limits() { - let limits = ModelConfig::get_all_model_limits(); - assert!(!limits.is_empty()); - - // Test that we can find specific patterns - let gpt4_limit = limits.iter().find(|l| l.pattern == "gpt-4o"); - assert!(gpt4_limit.is_some()); - assert_eq!(gpt4_limit.unwrap().context_limit, 128_000); + fn test_invalid_toolshim() { + with_var("GOOSE_TOOLSHIM", Some("maybe"), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); + if let Err(ConfigError::InvalidValue(var, val, msg)) = result { + assert_eq!(var, "GOOSE_TOOLSHIM"); + assert_eq!(val, "maybe"); + assert!(msg.contains("must be one of")); + } + }); } #[test] - #[serial_test::serial] - fn test_model_config_context_limit_env_vars() { - use temp_env::with_vars; - - // Test default context limit environment variable - with_vars([("GOOSE_CONTEXT_LIMIT", Some("250000"))], || { - let config = ModelConfig::new("unknown-model".to_string()); - assert_eq!(config.context_limit(), 250_000); + fn test_empty_toolshim_model() { + with_var("GOOSE_TOOLSHIM_OLLAMA_MODEL", Some(""), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidValue(_, _, _) + )); }); - // Test custom context limit environment variable - with_vars( - [ - ("GOOSE_LEAD_CONTEXT_LIMIT", Some("300000")), - ("GOOSE_CONTEXT_LIMIT", Some("250000")), - ], - || { - let config = ModelConfig::new_with_context_env( - "unknown-model".to_string(), - Some("GOOSE_LEAD_CONTEXT_LIMIT"), - ); - // Should use the custom env var, not the default one - assert_eq!(config.context_limit(), 300_000); - }, - ); - - // Test fallback to model-specific when env var is invalid - with_vars([("GOOSE_CONTEXT_LIMIT", Some("invalid"))], || { - let config = ModelConfig::new("gpt-4o".to_string()); - assert_eq!(config.context_limit(), 128_000); // Should use model-specific default + with_var("GOOSE_TOOLSHIM_OLLAMA_MODEL", Some(" "), || { + let result = ModelConfig::new("test-model"); + assert!(result.is_err()); }); + } - // Test fallback to default when no env vars and unknown model - let config = ModelConfig::new("unknown-model".to_string()); - assert_eq!(config.context_limit(), DEFAULT_CONTEXT_LIMIT); + #[test] + fn test_valid_configurations() { + with_var("GOOSE_CONTEXT_LIMIT", Some("50000"), || { + with_var("GOOSE_TEMPERATURE", Some("0.7"), || { + with_var("GOOSE_TOOLSHIM", Some("true"), || { + with_var("GOOSE_TOOLSHIM_OLLAMA_MODEL", Some("llama3"), || { + let config = ModelConfig::new("test-model").unwrap(); + assert_eq!(config.context_limit(), 50_000); + assert_eq!(config.temperature, Some(0.7)); + assert!(config.toolshim); + assert_eq!(config.toolshim_model, Some("llama3".to_string())); + }); + }); + }); + }); } } diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index cb1b4d7483b8..4b870e30a262 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -311,8 +311,8 @@ mod tests { } fn create_mock_provider() -> Arc { - let mock_model_config = - ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into()); + let config = ModelConfig::new_or_fail("test-model"); + let mock_model_config = config.with_context_limit(200_000.into()); Arc::new(MockProvider { model_config: mock_model_config, }) diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 508fdd3cb9d3..233937940648 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -17,6 +17,7 @@ use super::formats::anthropic::{ create_request, get_usage, response_to_message, response_to_streaming_message, }; use super::utils::{emit_debug_trace, get_model}; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -46,12 +47,7 @@ pub struct AnthropicProvider { model: ModelConfig, } -impl Default for AnthropicProvider { - fn default() -> Self { - let model = ModelConfig::new(AnthropicProvider::metadata().default_model); - AnthropicProvider::from_env(model).expect("Failed to initialize Anthropic provider") - } -} +impl_provider_default!(AnthropicProvider); impl AnthropicProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 916215ddf34c..b4122ffb0380 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -11,6 +11,7 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -51,12 +52,7 @@ impl Serialize for AzureProvider { } } -impl Default for AzureProvider { - fn default() -> Self { - let model = ModelConfig::new(AzureProvider::metadata().default_model); - AzureProvider::from_env(model).expect("Failed to initialize Azure OpenAI provider") - } -} +impl_provider_default!(AzureProvider); impl AzureProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index ea108e3fb9f0..8a031931a75b 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -116,7 +116,7 @@ impl ProviderMetadata { .iter() .map(|&name| ModelInfo { name: name.to_string(), - context_limit: ModelConfig::new(name.to_string()).context_limit(), + context_limit: ModelConfig::new_or_fail(name).context_limit(), input_token_cost: None, output_token_cost: None, currency: None, @@ -401,7 +401,6 @@ mod tests { use std::collections::HashMap; use serde_json::json; - #[test] fn test_usage_creation() { let usage = Usage::new(Some(10), Some(20), Some(30)); diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 22d05fd6a18c..ac823dff6808 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,6 +1,12 @@ use std::collections::HashMap; use std::time::Duration; +use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; +use super::errors::ProviderError; +use crate::impl_provider_default; +use crate::message::Message; +use crate::model::ModelConfig; +use crate::providers::utils::emit_debug_trace; use anyhow::Result; use async_trait::async_trait; use aws_sdk_bedrockruntime::config::ProvideCredentials; @@ -10,12 +16,6 @@ use rmcp::model::Tool; use serde_json::Value; use tokio::time::sleep; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; -use super::errors::ProviderError; -use crate::message::Message; -use crate::model::ModelConfig; -use crate::providers::utils::emit_debug_trace; - // Import the migrated helper functions from providers/formats/bedrock.rs use super::formats::bedrock::{ from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config, @@ -70,12 +70,7 @@ impl BedrockProvider { } } -impl Default for BedrockProvider { - fn default() -> Self { - let model = ModelConfig::new(BedrockProvider::metadata().default_model); - BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider") - } -} +impl_provider_default!(BedrockProvider); #[async_trait] impl Provider for BedrockProvider { diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 03909c7a326d..833fd4547aa4 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -11,6 +11,7 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::config::Config; +use crate::impl_provider_default; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -26,12 +27,7 @@ pub struct ClaudeCodeProvider { model: ModelConfig, } -impl Default for ClaudeCodeProvider { - fn default() -> Self { - let model = ModelConfig::new(ClaudeCodeProvider::metadata().default_model); - ClaudeCodeProvider::from_env(model).expect("Failed to initialize Claude Code provider") - } -} +impl_provider_default!(ClaudeCodeProvider); impl ClaudeCodeProvider { pub fn from_env(model: ModelConfig) -> Result { @@ -520,6 +516,7 @@ impl Provider for ClaudeCodeProvider { #[cfg(test)] mod tests { + use super::ModelConfig; use super::*; #[test] @@ -547,7 +544,7 @@ mod tests { #[test] fn test_claude_code_invalid_model_no_fallback() { // Test that an invalid model is kept as-is (no fallback) - let invalid_model = ModelConfig::new("invalid-model".to_string()); + let invalid_model = ModelConfig::new_or_fail("invalid-model"); let provider = ClaudeCodeProvider::from_env(invalid_model).unwrap(); let config = provider.get_model_config(); @@ -557,7 +554,7 @@ mod tests { #[test] fn test_claude_code_valid_model() { // Test that a valid model is preserved - let valid_model = ModelConfig::new("sonnet".to_string()); + let valid_model = ModelConfig::new_or_fail("sonnet"); let provider = ClaudeCodeProvider::from_env(valid_model).unwrap(); let config = provider.get_model_config(); diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 4b88cf5f6c86..cb75a5585f29 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -17,6 +17,7 @@ use super::formats::databricks::{create_request, response_to_message}; use super::oauth; use super::utils::{get_model, ImageFormat}; use crate::config::ConfigError; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{get_usage, response_to_streaming_message}; @@ -141,12 +142,7 @@ pub struct DatabricksProvider { retry_config: RetryConfig, } -impl Default for DatabricksProvider { - fn default() -> Self { - let model = ModelConfig::new(DatabricksProvider::metadata().default_model); - DatabricksProvider::from_env(model).expect("Failed to initialize Databricks provider") - } -} +impl_provider_default!(DatabricksProvider); impl DatabricksProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index c5eab4cfa08d..b0cb696b71bf 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -98,17 +98,16 @@ fn create_lead_worker_from_env( .get_param::("GOOSE_LEAD_FALLBACK_TURNS") .unwrap_or(default_fallback_turns()); - // Create model configs with context limit environment variable support let lead_model_config = ModelConfig::new_with_context_env( lead_model_name.to_string(), Some("GOOSE_LEAD_CONTEXT_LIMIT"), - ); + )?; // For worker model, preserve the original context_limit from config (highest precedence) // while still allowing environment variable overrides let worker_model_config = { // Start with a clone of the original model to preserve user-specified settings - let mut worker_config = ModelConfig::new(default_model.model_name.clone()) + let mut worker_config = ModelConfig::new_or_fail(default_model.model_name.as_str()) .with_context_limit(default_model.context_limit) .with_temperature(default_model.temperature) .with_max_tokens(default_model.max_tokens) @@ -242,7 +241,8 @@ mod tests { env::set_var("GOOSE_LEAD_MODEL", "gpt-4o"); // This will try to create a lead/worker provider - let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + let gpt4mini_config = ModelConfig::new_or_fail("gpt-4o-mini"); + let result = create("openai", gpt4mini_config.clone()); // The creation might succeed or fail depending on API keys, but we can verify the logic path match result { @@ -261,7 +261,7 @@ mod tests { env::set_var("GOOSE_LEAD_PROVIDER", "anthropic"); env::set_var("GOOSE_LEAD_TURNS", "5"); - let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + let _result = create("openai", gpt4mini_config); // Similar validation as above - will fail due to missing API keys but confirms the logic // Restore env vars @@ -305,7 +305,7 @@ mod tests { env::set_var("GOOSE_LEAD_MODEL", "grok-3"); // This should use defaults for all other values - let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + let result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")); // Should attempt to create lead/worker provider (will fail due to missing API keys but confirms logic) match result { @@ -324,7 +324,7 @@ mod tests { env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4"); env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3"); - let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + let _result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")); // Should still attempt to create lead/worker provider with custom settings // Restore all env vars @@ -353,7 +353,7 @@ mod tests { env::remove_var("GOOSE_LEAD_FALLBACK_TURNS"); // This should try to create a regular provider - let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + let result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")); // The creation might succeed or fail depending on API keys match result { @@ -368,7 +368,6 @@ mod tests { } } - // Restore env vars if let Some(val) = saved_lead { env::set_var("GOOSE_LEAD_MODEL", val); } @@ -410,7 +409,7 @@ mod tests { // Create a default model with explicit context_limit let default_model = - ModelConfig::new("gpt-3.5-turbo".to_string()).with_context_limit(Some(16_000)); + ModelConfig::new_or_fail("gpt-3.5-turbo").with_context_limit(Some(16_000)); // Test case 1: No environment variables - should preserve original context_limit let result = create_lead_worker_from_env("openai", &default_model, "gpt-4o"); diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index bc0f715ec1ce..0c451dea924d 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -911,15 +911,11 @@ mod tests { #[test] fn test_create_request_with_thinking() -> Result<()> { - // Save the original env var value if it exists let original_value = std::env::var("CLAUDE_THINKING_ENABLED").ok(); - - // Set the env var for this test std::env::set_var("CLAUDE_THINKING_ENABLED", "true"); - // Execute the test let result = (|| { - let model_config = ModelConfig::new("claude-3-7-sonnet-20250219".to_string()); + let model_config = ModelConfig::new_or_fail("claude-3-7-sonnet-20250219"); let system = "You are a helpful assistant."; let messages = vec![Message::user().with_text("Hello")]; let tools = vec![]; diff --git a/crates/goose/src/providers/formats/snowflake.rs b/crates/goose/src/providers/formats/snowflake.rs index 270973024bb1..50669fe3c08b 100644 --- a/crates/goose/src/providers/formats/snowflake.rs +++ b/crates/goose/src/providers/formats/snowflake.rs @@ -548,7 +548,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet"," fn test_create_request_format() -> Result<()> { use crate::model::ModelConfig; - let model_config = ModelConfig::new("claude-3-5-sonnet".to_string()); + let model_config = ModelConfig::new_or_fail("claude-3-5-sonnet"); let system = "You are a helpful assistant that can use tools to get information."; let messages = vec![Message::user().with_text("What is the stock price of Nvidia?")]; @@ -656,7 +656,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet"," fn test_create_request_excludes_tools_for_description() -> Result<()> { use crate::model::ModelConfig; - let model_config = ModelConfig::new("claude-3-5-sonnet".to_string()); + let model_config = ModelConfig::new_or_fail("claude-3-5-sonnet"); let system = "Reply with only a description in four words or less"; let messages = vec![Message::user().with_text("Test message")]; let tools = vec![Tool::new( diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 2544a1356953..e5d58f303275 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -18,6 +18,7 @@ use crate::providers::formats::gcpvertexai::{ ModelProvider, RequestContext, }; +use crate::impl_provider_default; use crate::providers::formats::gcpvertexai::GcpLocation::Iowa; use crate::providers::gcpauth::GcpAuth; use crate::providers::utils::emit_debug_trace; @@ -505,12 +506,7 @@ impl GcpVertexAIProvider { } } -impl Default for GcpVertexAIProvider { - fn default() -> Self { - let model = ModelConfig::new(Self::metadata().default_model); - Self::new(model).expect("Failed to initialize VertexAI provider") - } -} +impl_provider_default!(GcpVertexAIProvider); #[async_trait] impl Provider for GcpVertexAIProvider { @@ -711,7 +707,7 @@ mod tests { fn test_url_construction() { use url::Url; - let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string()); + let model_config = ModelConfig::new_or_fail("claude-3-5-sonnet-v2@20241022"); let context = RequestContext::new(&model_config.model_name).unwrap(); let api_model_id = context.model.to_string(); diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 387af4987d08..17ebb7ba2314 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -9,6 +9,7 @@ use tokio::process::Command; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; +use crate::impl_provider_default; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use rmcp::model::Role; @@ -25,12 +26,7 @@ pub struct GeminiCliProvider { model: ModelConfig, } -impl Default for GeminiCliProvider { - fn default() -> Self { - let model = ModelConfig::new(GeminiCliProvider::metadata().default_model); - GeminiCliProvider::from_env(model).expect("Failed to initialize Gemini CLI provider") - } -} +impl_provider_default!(GeminiCliProvider); impl GeminiCliProvider { pub fn from_env(model: ModelConfig) -> Result { @@ -376,7 +372,7 @@ mod tests { #[test] fn test_gemini_cli_invalid_model_no_fallback() { // Test that an invalid model is kept as-is (no fallback) - let invalid_model = ModelConfig::new("invalid-model".to_string()); + let invalid_model = ModelConfig::new_or_fail("invalid-model"); let provider = GeminiCliProvider::from_env(invalid_model).unwrap(); let config = provider.get_model_config(); @@ -386,7 +382,7 @@ mod tests { #[test] fn test_gemini_cli_valid_model() { // Test that a valid model is preserved - let valid_model = ModelConfig::new(GEMINI_CLI_DEFAULT_MODEL.to_string()); + let valid_model = ModelConfig::new_or_fail(GEMINI_CLI_DEFAULT_MODEL); let provider = GeminiCliProvider::from_env(valid_model).unwrap(); let config = provider.get_model_config(); diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 831209b28da8..899ce0cb4033 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -17,6 +17,7 @@ use super::formats::openai::{create_request, get_usage, response_to_message}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::config::{Config, ConfigError}; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::ConfigKey; @@ -115,12 +116,7 @@ pub struct GithubCopilotProvider { model: ModelConfig, } -impl Default for GithubCopilotProvider { - fn default() -> Self { - let model = ModelConfig::new(GithubCopilotProvider::metadata().default_model); - GithubCopilotProvider::from_env(model).expect("Failed to initialize GithubCopilot provider") - } -} +impl_provider_default!(GithubCopilotProvider); impl GithubCopilotProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 6967128af7fd..a499b08c51a4 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -1,4 +1,5 @@ use super::errors::ProviderError; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; @@ -55,12 +56,7 @@ pub struct GoogleProvider { model: ModelConfig, } -impl Default for GoogleProvider { - fn default() -> Self { - let model = ModelConfig::new(GoogleProvider::metadata().default_model); - GoogleProvider::from_env(model).expect("Failed to initialize Google provider") - } -} +impl_provider_default!(GoogleProvider); impl GoogleProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index e4c843702d90..3a9d48a787f3 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,4 +1,5 @@ use super::errors::ProviderError; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -32,12 +33,7 @@ pub struct GroqProvider { model: ModelConfig, } -impl Default for GroqProvider { - fn default() -> Self { - let model = ModelConfig::new(GroqProvider::metadata().default_model); - GroqProvider::from_env(model).expect("Failed to initialize Groq provider") - } -} +impl_provider_default!(GroqProvider); impl GroqProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 7909d5059e94..4ddea247048d 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -501,12 +501,12 @@ mod tests { async fn test_lead_worker_switching() { let lead_provider = Arc::new(MockProvider { name: "lead".to_string(), - model_config: ModelConfig::new("lead-model".to_string()), + model_config: ModelConfig::new_or_fail("lead-model"), }); let worker_provider = Arc::new(MockProvider { name: "worker".to_string(), - model_config: ModelConfig::new("worker-model".to_string()), + model_config: ModelConfig::new_or_fail("worker-model"), }); let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3)); @@ -541,13 +541,13 @@ mod tests { async fn test_technical_failure_retry() { let lead_provider = Arc::new(MockFailureProvider { name: "lead".to_string(), - model_config: ModelConfig::new("lead-model".to_string()), + model_config: ModelConfig::new_or_fail("lead-model"), should_fail: false, // Lead provider works }); let worker_provider = Arc::new(MockFailureProvider { name: "worker".to_string(), - model_config: ModelConfig::new("worker-model".to_string()), + model_config: ModelConfig::new_or_fail("worker-model"), should_fail: true, // Worker will fail }); @@ -583,13 +583,13 @@ mod tests { // For now, we'll test the fallback mode functionality directly let lead_provider = Arc::new(MockFailureProvider { name: "lead".to_string(), - model_config: ModelConfig::new("lead-model".to_string()), + model_config: ModelConfig::new_or_fail("lead-model"), should_fail: false, }); let worker_provider = Arc::new(MockFailureProvider { name: "worker".to_string(), - model_config: ModelConfig::new("worker-model".to_string()), + model_config: ModelConfig::new_or_fail("worker-model"), should_fail: false, }); diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index 303e3aaeb542..6991d823dfbf 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -10,6 +10,7 @@ use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsag use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -28,12 +29,7 @@ pub struct LiteLLMProvider { custom_headers: Option>, } -impl Default for LiteLLMProvider { - fn default() -> Self { - let model = ModelConfig::new(LiteLLMProvider::metadata().default_model); - LiteLLMProvider::from_env(model).expect("Failed to initialize LiteLLM provider") - } -} +impl_provider_default!(LiteLLMProvider); impl LiteLLMProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 4001ee857315..5da98e19ab7f 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,6 +1,7 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::{get_model, handle_response_openai_compat}; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; @@ -28,12 +29,7 @@ pub struct OllamaProvider { model: ModelConfig, } -impl Default for OllamaProvider { - fn default() -> Self { - let model = ModelConfig::new(OllamaProvider::metadata().default_model); - OllamaProvider::from_env(model).expect("Failed to initialize Ollama provider") - } -} +impl_provider_default!(OllamaProvider); impl OllamaProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index ff6b65253c93..e57e9ae46286 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -17,6 +17,7 @@ use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::MessageStream; @@ -50,12 +51,7 @@ pub struct OpenAiProvider { custom_headers: Option>, } -impl Default for OpenAiProvider { - fn default() -> Self { - let model = ModelConfig::new(OpenAiProvider::metadata().default_model); - OpenAiProvider::from_env(model).expect("Failed to initialize OpenAI provider") - } -} +impl_provider_default!(OpenAiProvider); impl OpenAiProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 578add95dd2c..e0613b5e59ea 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -10,6 +10,7 @@ use super::utils::{ emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat, is_google_model, }; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; @@ -38,12 +39,7 @@ pub struct OpenRouterProvider { model: ModelConfig, } -impl Default for OpenRouterProvider { - fn default() -> Self { - let model = ModelConfig::new(OpenRouterProvider::metadata().default_model); - OpenRouterProvider::from_env(model).expect("Failed to initialize OpenRouter provider") - } -} +impl_provider_default!(OpenRouterProvider); impl OpenRouterProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 9421d75b2e34..d2656aad6f21 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -13,6 +13,7 @@ use tokio::time::sleep; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; +use crate::impl_provider_default; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use chrono::Utc; @@ -254,12 +255,7 @@ impl SageMakerTgiProvider { } } -impl Default for SageMakerTgiProvider { - fn default() -> Self { - let model = ModelConfig::new(SageMakerTgiProvider::metadata().default_model); - SageMakerTgiProvider::from_env(model).expect("Failed to initialize SageMaker TGI provider") - } -} +impl_provider_default!(SageMakerTgiProvider); #[async_trait] impl Provider for SageMakerTgiProvider { diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index a19bbd11445d..3e52310ee54b 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -10,6 +10,7 @@ use super::errors::ProviderError; use super::formats::snowflake::{create_request, get_usage, response_to_message}; use super::utils::{get_model, ImageFormat}; use crate::config::ConfigError; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -42,12 +43,7 @@ pub struct SnowflakeProvider { image_format: ImageFormat, } -impl Default for SnowflakeProvider { - fn default() -> Self { - let model = ModelConfig::new(SnowflakeProvider::metadata().default_model); - SnowflakeProvider::from_env(model).expect("Failed to initialize Snowflake provider") - } -} +impl_provider_default!(SnowflakeProvider); impl SnowflakeProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index 899568a5433a..c25ad0022105 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -155,7 +155,7 @@ impl Provider for TestProvider { } fn get_model_config(&self) -> ModelConfig { - ModelConfig::new("test-model".to_string()) + ModelConfig::new_or_fail("test-model") } } @@ -223,7 +223,7 @@ mod tests { ); let mock = Arc::new(MockProvider { - model_config: ModelConfig::new("mock-model".to_string()), + model_config: ModelConfig::new_or_fail("mock-model"), response: "Hello, world!".to_string(), }); diff --git a/crates/goose/src/providers/toolshim.rs b/crates/goose/src/providers/toolshim.rs index 1eb43db2d23e..cae32e51baeb 100644 --- a/crates/goose/src/providers/toolshim.rs +++ b/crates/goose/src/providers/toolshim.rs @@ -153,7 +153,8 @@ impl OllamaInterpreter { let user_message = Message::user().with_text(format_instruction); messages.push(user_message); - let model_config = ModelConfig::new(model.to_string()); + let model_config = ModelConfig::new(model) + .map_err(|e| ProviderError::RequestFailed(format!("Model config error: {e}")))?; let mut payload = create_request( &model_config, diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 75cb31145b50..67acd4483f09 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -8,6 +8,7 @@ use std::time::Duration; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; +use crate::impl_provider_default; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use mcp_core::{ToolCall, ToolResult}; @@ -80,12 +81,7 @@ pub struct VeniceProvider { model: ModelConfig, } -impl Default for VeniceProvider { - fn default() -> Self { - let model = ModelConfig::new(VENICE_DEFAULT_MODEL.to_string()); - VeniceProvider::from_env(model).expect("Failed to initialize Venice provider") - } -} +impl_provider_default!(VeniceProvider); impl VeniceProvider { pub fn from_env(mut model: ModelConfig) -> Result { diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 6d24e087b3b6..3b8596a63284 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -1,4 +1,5 @@ use super::errors::ProviderError; +use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -45,12 +46,7 @@ pub struct XaiProvider { model: ModelConfig, } -impl Default for XaiProvider { - fn default() -> Self { - let model = ModelConfig::new(XaiProvider::metadata().default_model); - XaiProvider::from_env(model).expect("Failed to initialize xAI provider") - } -} +impl_provider_default!(XaiProvider); impl XaiProvider { pub fn from_env(model: ModelConfig) -> Result { diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 52baff428b51..09aac54a593b 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -432,7 +432,6 @@ impl RecipeBuilder { #[cfg(test)] mod tests { use super::*; - use std::fs; #[test] fn test_from_content_with_json() { diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 9d879ac58fe5..0c6a304b09ee 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1140,7 +1140,12 @@ async fn run_scheduled_job_internal( .to_string(), }), }; - let model_config = crate::model::ModelConfig::new(model_name.clone()); + let model_config = + crate::model::ModelConfig::new(model_name.as_str()).map_err(|e| JobExecutionError { + job_id: job.id.clone(), + error: format!("Model config error: {}", e), + })?; + agent_provider = create(&provider_name, model_config).map_err(|e| JobExecutionError { job_id: job.id.clone(), error: format!( @@ -1448,8 +1453,7 @@ mod tests { execution_mode: Some("background".to_string()), // Default for test }; - // Create the mock provider instance for the test - let mock_model_config = ModelConfig::new("test_model".to_string()); + let mock_model_config = ModelConfig::new_or_fail("test_model"); let mock_provider_instance = create_scheduler_test_mock_provider(mock_model_config); // Call run_scheduled_job_internal, passing the mock provider From ba816bfa5743566e708f35e9dd6ab89c4b4dfe33 Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Wed, 30 Jul 2025 12:10:55 +1000 Subject: [PATCH 2/3] updating to main and fixing compile errors --- crates/goose-cli/src/commands/configure.rs | 6 +++--- crates/goose-cli/src/commands/web.rs | 2 +- crates/goose-cli/src/session/builder.rs | 8 ++++++-- crates/goose-cli/src/session/mod.rs | 2 +- crates/goose-server/src/routes/agent.rs | 2 +- crates/goose/src/model.rs | 2 +- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 3a8e792ee5e2..d91138ecac30 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -332,7 +332,7 @@ pub async fn configure_provider_dialog() -> Result> { let spin = spinner(); spin.start("Attempting to fetch supported models..."); let models_res = { - let temp_model_config = goose::model::ModelConfig::new(provider_meta.default_model.clone()); + let temp_model_config = goose::model::ModelConfig::new(&provider_meta.default_model)?; let temp_provider = create(provider_name, temp_model_config)?; temp_provider.fetch_supported_models_async().await }; @@ -373,7 +373,7 @@ pub async fn configure_provider_dialog() -> Result> { .map(|val| val == "1" || val.to_lowercase() == "true") .unwrap_or(false); - let model_config = goose::model::ModelConfig::new(model.clone()) + let model_config = goose::model::ModelConfig::new(&model)? .with_max_tokens(Some(50)) .with_toolshim(toolshim_enabled) .with_toolshim_model(std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL").ok()); @@ -1231,7 +1231,7 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { let model: String = config .get_param("GOOSE_MODEL") .expect("No model configured. Please set model first"); - let model_config = goose::model::ModelConfig::new(model.clone()); + let model_config = goose::model::ModelConfig::new(&model)?; // Create the agent let agent = Agent::new(); diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index ba5206b4ba2e..287507258f74 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -99,7 +99,7 @@ pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> { } }; - let model_config = goose::model::ModelConfig::new(model.clone()); + let model_config = goose::model::ModelConfig::new(&model)?; // Create the agent let agent = Agent::new(); diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index d4b1b812a1ed..ae2adc9bf6b8 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -202,8 +202,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { let temperature = session_config.settings.as_ref().and_then(|s| s.temperature); - let model_config = - goose::model::ModelConfig::new(model_name.clone()).with_temperature(temperature); + let model_config = goose::model::ModelConfig::new(&model_name) + .unwrap_or_else(|e| { + output::render_error(&format!("Failed to create model configuration: {}", e)); + process::exit(1); + }) + .with_temperature(temperature); // Create the agent let agent: Agent = Agent::new(); diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 79d155681198..e9cf80a233b3 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1572,7 +1572,7 @@ fn get_reasoner() -> Result, anyhow::Error> { }; let model_config = - ModelConfig::new_with_context_env(model, Some("GOOSE_PLANNER_CONTEXT_LIMIT")); + ModelConfig::new_with_context_env(model, Some("GOOSE_PLANNER_CONTEXT_LIMIT"))?; let reasoner = create(&provider, model_config)?; Ok(reasoner) diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 1023b22dbbaf..bb416993b3cc 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -255,7 +255,7 @@ async fn update_agent_provider( .get_param("GOOSE_MODEL") .expect("Did not find a model on payload or in env to update provider with") }); - let model_config = ModelConfig::new(model); + let model_config = ModelConfig::new(&model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let new_provider = create(&payload.provider, model_config).unwrap(); agent .update_provider(new_provider) diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index b41fd5195a38..dad3f74c2bda 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -236,7 +236,7 @@ impl ModelConfig { pub fn new_or_fail(model_name: &str) -> ModelConfig { ModelConfig::new(model_name) - .expect(&format!("Failed to create model config for {}", model_name)) + .unwrap_or_else(|_| panic!("Failed to create model config for {}", model_name)) } } From 2ea4a1ba5df907ab9f9867c5657bfa1ece205732 Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Wed, 30 Jul 2025 12:36:24 +1000 Subject: [PATCH 3/3] test time compile errors --- .../goose-cli/src/scenario_tests/scenario_runner.rs | 5 +---- crates/goose-server/src/routes/reply.rs | 2 +- crates/goose/tests/agent.rs | 11 ++++++----- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 51018b983b0c..687d6e413a80 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -176,10 +176,7 @@ where let original_env = setup_environment(config)?; - let inner_provider = create( - &factory_name, - ModelConfig::new(config.model_name.to_string()), - )?; + let inner_provider = create(&factory_name, ModelConfig::new(&config.model_name)?)?; let test_provider = Arc::new(TestProvider::new_recording(inner_provider, &file_path)); ( diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 12cfae3a6d9a..d3ac7208e0a0 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -441,7 +441,7 @@ mod tests { #[tokio::test] async fn test_reply_endpoint() { - let mock_model_config = ModelConfig::new("test-model".to_string()); + let mock_model_config = ModelConfig::new("test-model").unwrap(); let mock_provider = Arc::new(MockProvider { model_config: mock_model_config, }); diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 016e9576ebea..8ef7854576b8 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -108,7 +108,8 @@ async fn run_truncate_test( model: &str, context_window: usize, ) -> Result<()> { - let model_config = ModelConfig::new(model.to_string()) + let model_config = ModelConfig::new(model) + .unwrap() .with_context_limit(Some(context_window)) .with_temperature(Some(0.0)); let provider = provider_type.create_provider(model_config)?; @@ -584,7 +585,7 @@ mod final_output_tool_tests { let agent = Agent::new(); - let model_config = ModelConfig::new("test-model".to_string()); + let model_config = ModelConfig::new("test-model").unwrap(); let mock_provider = Arc::new(MockProvider { model_config }); agent.update_provider(mock_provider).await?; @@ -704,7 +705,7 @@ mod final_output_tool_tests { let agent = Agent::new(); - let model_config = ModelConfig::new("test-model".to_string()); + let model_config = ModelConfig::new("test-model").unwrap(); let mock_provider = Arc::new(MockProvider { model_config }); agent.update_provider(mock_provider).await?; @@ -820,7 +821,7 @@ mod retry_tests { async fn test_retry_config_validation_integration() -> Result<()> { let agent = Agent::new(); - let model_config = ModelConfig::new("test-model".to_string()); + let model_config = ModelConfig::new("test-model").unwrap(); let mock_provider = Arc::new(MockRetryProvider { model_config, call_count: Arc::new(AtomicUsize::new(0)), @@ -986,7 +987,7 @@ mod max_turns_tests { } fn get_model_config(&self) -> ModelConfig { - ModelConfig::new("mock-model".to_string()) + ModelConfig::new("mock-model").unwrap() } fn metadata() -> ProviderMetadata {