diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs
index 66480c6fb070..dd45ff598903 100644
--- a/crates/goose-server/src/routes/reply.rs
+++ b/crates/goose-server/src/routes/reply.rs
@@ -537,8 +537,9 @@ mod tests {
             goose::providers::base::ProviderMetadata::empty()
         }
 
-        async fn complete(
+        async fn complete_with_model(
             &self,
+            _model_config: &ModelConfig,
             _system: &str,
             _messages: &[Message],
             _tools: &[rmcp::model::Tool],
diff --git a/crates/goose/src/context_mgmt/auto_compact.rs b/crates/goose/src/context_mgmt/auto_compact.rs
index 062004650b69..4fce6d18be3d 100644
--- a/crates/goose/src/context_mgmt/auto_compact.rs
+++ b/crates/goose/src/context_mgmt/auto_compact.rs
@@ -221,8 +221,9 @@ mod tests {
             self.model_config.clone()
         }
 
-        async fn complete(
+        async fn complete_with_model(
             &self,
+            _model_config: &ModelConfig,
             _system: &str,
             _messages: &[Message],
             _tools: &[Tool],
diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs
index 68947cfa98f7..360ef2a08d9c 100644
--- a/crates/goose/src/context_mgmt/summarize.rs
+++ b/crates/goose/src/context_mgmt/summarize.rs
@@ -44,7 +44,7 @@ pub async fn summarize_messages(
     // Send the request to the provider and fetch the response
     let (mut response, mut provider_usage) = provider
-        .complete(&system_prompt, &summarization_request, &[])
+        .complete_fast(&system_prompt, &summarization_request, &[])
         .await?;
 
     // Set role to user as it will be used in following conversation as user content
@@ -87,8 +87,9 @@ mod tests {
             self.model_config.clone()
         }
 
-        async fn complete(
+        async fn complete_with_model(
             &self,
+            _model_config: &ModelConfig,
             _system: &str,
             _messages: &[Message],
             _tools: &[Tool],
diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs
index b606eaa3b043..228013f0b469 100644
--- a/crates/goose/src/model.rs
+++ b/crates/goose/src/model.rs
@@ -28,6 +28,7 @@ static MODEL_SPECIFIC_LIMITS: Lazy<Vec<(&'static str, usize)>> = Lazy::new(|| {
         // anthropic - all 200k
         ("claude", 200_000),
         // google
+        ("gemini-1.5-flash", 1_000_000),
         ("gemini-1", 128_000),
         ("gemini-2", 1_000_000),
         ("gemma-3-27b", 128_000),
@@ -72,6 +73,7 @@ pub struct ModelConfig {
     pub max_tokens: Option<i32>,
     pub toolshim: bool,
     pub toolshim_model: Option<String>,
+    pub fast_model: Option<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -89,7 +91,7 @@ impl ModelConfig {
         model_name: String,
         context_env_var: Option<&str>,
     ) -> Result<Self, ConfigError> {
-        let context_limit = Self::parse_context_limit(&model_name, context_env_var)?;
+        let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
         let temperature = Self::parse_temperature()?;
         let toolshim = Self::parse_toolshim()?;
         let toolshim_model = Self::parse_toolshim_model()?;
@@ -101,13 +103,16 @@
             max_tokens: None,
             toolshim,
             toolshim_model,
+            fast_model: None,
         })
     }
 
     fn parse_context_limit(
         model_name: &str,
+        fast_model: Option<&str>,
         custom_env_var: Option<&str>,
     ) -> Result<Option<usize>, ConfigError> {
+        // First check if there's an explicit environment variable override
        if let Some(env_var) = custom_env_var {
            if let Ok(val) = std::env::var(env_var) {
                return Self::validate_context_limit(&val, env_var).map(Some);
@@ -116,7 +121,24 @@
         if let Ok(val) = std::env::var("GOOSE_CONTEXT_LIMIT") {
             return Self::validate_context_limit(&val, "GOOSE_CONTEXT_LIMIT").map(Some);
         }
-        Ok(Self::get_model_specific_limit(model_name))
+
+        // Get the model's limit
+        let model_limit = Self::get_model_specific_limit(model_name);
+
+        // If there's a fast_model, get its limit and use the minimum
+        if let Some(fast_model_name) = fast_model {
+            let fast_model_limit = Self::get_model_specific_limit(fast_model_name);
+
+            // Return the minimum of both limits (if both exist)
+            match (model_limit, fast_model_limit) {
+                (Some(m), Some(f)) => Ok(Some(m.min(f))),
+                (Some(m), None) => Ok(Some(m)),
+                (None, Some(f)) => Ok(Some(f)),
+                (None, None) => Ok(None),
+            }
+        } else {
+            Ok(model_limit)
+        }
     }
 
     fn validate_context_limit(val: &str, env_var: &str) -> Result<usize, ConfigError> {
@@ -231,8 +253,39 @@ impl ModelConfig {
         self
     }
 
+    pub fn with_fast(mut self, fast_model: String) -> Self {
+        self.fast_model = Some(fast_model);
+        self
+    }
+
+    pub fn use_fast_model(&self) -> Self {
+        if let Some(fast_model) = &self.fast_model {
+            let mut config = self.clone();
+            config.model_name = fast_model.clone();
+            config
+        } else {
+            self.clone()
+        }
+    }
+
     pub fn context_limit(&self) -> usize {
-        self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
+        // If we have an explicit context limit set, use it
+        if let Some(limit) = self.context_limit {
+            return limit;
+        }
+
+        // Otherwise, get the model's default limit
+        let main_limit =
+            Self::get_model_specific_limit(&self.model_name).unwrap_or(DEFAULT_CONTEXT_LIMIT);
+
+        // If we have a fast_model, also check its limit and use the minimum
+        if let Some(fast_model) = &self.fast_model {
+            let fast_limit =
+                Self::get_model_specific_limit(fast_model).unwrap_or(DEFAULT_CONTEXT_LIMIT);
+            main_limit.min(fast_limit)
+        } else {
+            main_limit
+        }
     }
 
     pub fn new_or_fail(model_name: &str) -> ModelConfig {
diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs
index 71438d38cb4c..922b8e21ea44 100644
--- a/crates/goose/src/permission/permission_judge.rs
+++ b/crates/goose/src/permission/permission_judge.rs
@@ -292,8 +292,9 @@ mod tests {
             self.model_config.clone()
         }
 
-        async fn complete(
+        async fn complete_with_model(
             &self,
+            _model_config: &ModelConfig,
             _system: &str,
             _messages: &[Message],
             _tools: &[Tool],
diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs
index 952006ee25de..9db8d048e851 100644
--- a/crates/goose/src/providers/anthropic.rs
+++ b/crates/goose/src/providers/anthropic.rs
@@ -23,6 +23,7 @@ use crate::providers::retry::ProviderRetry;
 use rmcp::model::Tool;
 
 const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-0";
+const ANTHROPIC_DEFAULT_FAST_MODEL: &str = "claude-3-7-sonnet-latest";
 const ANTHROPIC_KNOWN_MODELS: &[&str] = &[
     "claude-sonnet-4-0",
     "claude-sonnet-4-20250514",
@@ -50,6 +51,8 @@ impl_provider_default!(AnthropicProvider);
 
 impl AnthropicProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let model = model.with_fast(ANTHROPIC_DEFAULT_FAST_MODEL.to_string());
+
         let config = crate::config::Config::global();
         let api_key: String = config.get_secret("ANTHROPIC_API_KEY")?;
         let host: String = config
@@ -179,16 +182,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools)?;
+        let payload = create_request(model_config, system, messages, tools)?;
         let response = self
             .with_retry(|| async { self.post(&payload).await })
@@ -201,9 +205,9 @@
         tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
             usage.input_tokens, usage.output_tokens, usage.total_tokens);
 
-        let model = get_model(&json_response);
+        let response_model = get_model(&json_response);
         emit_debug_trace(&self.model, &payload, &json_response, &usage);
-        let provider_usage = ProviderUsage::new(model, usage);
+        let provider_usage = ProviderUsage::new(response_model, usage);
         tracing::debug!(
             "🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
             provider_usage
@@ -271,7 +275,7 @@
 
         let stream = response.bytes_stream().map_err(io::Error::other);
 
-        let model_config = self.model.clone();
+        let model = self.model.clone();
         Ok(Box::pin(try_stream! {
             let stream_reader = StreamReader::new(stream);
             let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from);
@@ -280,7 +284,7 @@
             pin!(message_stream);
             while let Some(message) = futures::StreamExt::next(&mut message_stream).await {
                 let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
-                emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
+                emit_debug_trace(&model, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
                 yield (message, usage);
             }
         }))
diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs
index f40993d67657..12bf93f20e3d 100644
--- a/crates/goose/src/providers/azure.rs
+++ b/crates/goose/src/providers/azure.rs
@@ -135,16 +135,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
+        let payload = create_request(model_config, system, messages, tools, &ImageFormat::OpenAi)?;
         let response = self
             .with_retry(|| async {
                 let payload_clone = payload.clone();
@@ -157,8 +158,8 @@
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let model = get_model(&response);
-        emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 }
diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs
index 60623abb3a3e..bed31d59909c 100644
--- a/crates/goose/src/providers/base.rs
+++ b/crates/goose/src/providers/base.rs
@@ -317,26 +317,41 @@ pub trait Provider: Send + Sync {
     where
         Self: Sized;
 
-    /// Generate the next message using the configured model and other parameters
-    ///
-    /// # Arguments
-    /// * `system` - The system prompt that guides the model's behavior
-    /// * `messages` - The conversation history as a sequence of messages
-    /// * `tools` - Optional list of tools the model can use
-    ///
-    /// # Returns
-    /// A tuple containing the model's response message and provider usage statistics
-    ///
-    /// # Errors
-    /// ProviderError
-    /// - It's important to raise ContextLengthExceeded correctly since agent handles it
-    async fn complete(
+    // Internal implementation of complete, used by complete_fast and complete
+    // Providers should override this to implement their actual completion logic
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError>;
 
+    // Default implementation: use the provider's configured model
+    async fn complete(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<(Message, ProviderUsage), ProviderError> {
+        let model_config = self.get_model_config();
+        self.complete_with_model(&model_config, system, messages, tools)
+            .await
+    }
+
+    // Check if a fast model is configured, otherwise fall back to regular model
+    async fn complete_fast(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<(Message, ProviderUsage), ProviderError> {
+        let model_config = self.get_model_config();
+        let fast_config = model_config.use_fast_model();
+        self.complete_with_model(&fast_config, system, messages, tools)
+            .await
+    }
+
     /// Get the model config from the provider
     fn get_model_config(&self) -> ModelConfig;
@@ -418,7 +433,7 @@ pub trait Provider: Send + Sync {
         let prompt = self.create_session_name_prompt(&context);
         let message = Message::user().with_text(&prompt);
         let result = self
-            .complete(
+            .complete_fast(
                 "Reply with only a description in four words or less",
                 &[message],
                 &[],
diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs
index 7579a7de7141..a2e04dbbf539 100644
--- a/crates/goose/src/providers/bedrock.rs
+++ b/crates/goose/src/providers/bedrock.rs
@@ -152,16 +152,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let model_name = &self.model.model_name;
+        let model_name = model_config.model_name.clone();
 
         let (bedrock_message, bedrock_usage) = self
             .with_retry(|| self.converse(system, messages, tools))
diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs
index 3185a7961bc2..9ef2950c272c 100644
--- a/crates/goose/src/providers/claude_code.rs
+++ b/crates/goose/src/providers/claude_code.rs
@@ -474,11 +474,12 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -495,7 +496,7 @@
         // Create a dummy payload for debug tracing
         let payload = json!({
             "command": self.command,
-            "model": self.model.model_name,
+            "model": model_config.model_name,
             "system": system,
             "messages": messages.len()
         });
@@ -505,11 +506,11 @@
             "usage": usage
         });
 
-        emit_debug_trace(&self.model, &payload, &response, &usage);
+        emit_debug_trace(model_config, &payload, &response, &usage);
 
         Ok((
             message,
-            ProviderUsage::new(self.model.model_name.clone(), usage),
+            ProviderUsage::new(model_config.model_name.clone(), usage),
         ))
     }
 }
diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs
index 432093df0b51..bbce315f648b 100644
--- a/crates/goose/src/providers/cursor_agent.rs
+++ b/crates/goose/src/providers/cursor_agent.rs
@@ -407,11 +407,12 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -428,7 +429,7 @@
         // Create a dummy payload for debug tracing
         let payload = json!({
             "command": self.command,
-            "model": self.model.model_name,
+            "model": model_config.model_name,
             "system": system,
             "messages": messages.len()
         });
@@ -438,11 +439,11 @@
             "usage": usage
         });
 
-        emit_debug_trace(&self.model, &payload, &response, &usage);
+        emit_debug_trace(model_config, &payload, &response, &usage);
 
         Ok((
             message,
-            ProviderUsage::new(self.model.model_name.clone(), usage),
+            ProviderUsage::new(model_config.model_name.clone(), usage),
         ))
     }
 }
diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs
index c635fe589470..0172eba58df8 100644
--- a/crates/goose/src/providers/databricks.rs
+++ b/crates/goose/src/providers/databricks.rs
@@ -37,6 +37,7 @@ const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
 const DEFAULT_TIMEOUT_SECS: u64 = 600;
 
 pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
+const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash";
 pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
     "databricks-meta-llama-3-3-70b-instruct",
     "databricks-meta-llama-3-1-405b-instruct",
@@ -137,13 +138,41 @@ impl DatabricksProvider {
         let api_client =
             ApiClient::with_timeout(host, auth_method, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
 
-        Ok(Self {
+        // Create the provider without the fast model first
+        let mut provider = Self {
             api_client,
             auth,
-            model,
+            model: model.clone(),
             image_format: ImageFormat::OpenAi,
             retry_config,
-        })
+        };
+
+        // Check if the default fast model exists in the workspace
+        let model_with_fast = tokio::task::block_in_place(|| {
+            tokio::runtime::Handle::current().block_on(async {
+                if let Ok(Some(models)) = provider.fetch_supported_models().await {
+                    if models.contains(&DATABRICKS_DEFAULT_FAST_MODEL.to_string()) {
+                        tracing::debug!(
+                            "Found {} in Databricks workspace, setting as fast model",
+                            DATABRICKS_DEFAULT_FAST_MODEL
+                        );
+                        model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL.to_string())
+                    } else {
+                        tracing::debug!(
+                            "{} not found in Databricks workspace, not setting fast model",
+                            DATABRICKS_DEFAULT_FAST_MODEL
+                        );
+                        model
+                    }
+                } else {
+                    tracing::debug!("Could not fetch Databricks models, not setting fast model");
+                    model
+                }
+            })
+        });
+
+        provider.model = model_with_fast;
+        Ok(provider)
     }
 
     fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
@@ -195,17 +224,18 @@
         })
     }
 
-    fn get_endpoint_path(&self, is_embedding: bool) -> String {
+    fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String {
         if is_embedding {
"serving-endpoints/text-embedding-3-small/invocations".to_string() } else { - format!("serving-endpoints/{}/invocations", self.model.model_name) + format!("serving-endpoints/{}/invocations", model_name) } } - async fn post(&self, payload: Value) -> Result { + async fn post(&self, payload: Value, model_name: Option<&str>) -> Result { let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none(); - let path = self.get_endpoint_path(is_embedding); + let model_to_use = model_name.unwrap_or(&self.model.model_name); + let path = self.get_endpoint_path(model_to_use, is_embedding); let response = self.api_client.response_post(&path, &payload).await?; handle_response_openai_compat(response).await @@ -238,32 +268,36 @@ impl Provider for DatabricksProvider { } #[tracing::instrument( - skip(self, system, messages, tools), + skip(self, model_config, system, messages, tools), fields(model_config, input, output, input_tokens, output_tokens, total_tokens) )] - async fn complete( + async fn complete_with_model( &self, + model_config: &ModelConfig, system: &str, messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { - let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?; + let mut payload = + create_request(model_config, system, messages, tools, &self.image_format)?; payload .as_object_mut() .expect("payload should have model key") .remove("model"); - let response = self.with_retry(|| self.post(payload.clone())).await?; + let response = self + .with_retry(|| self.post(payload.clone(), Some(&model_config.model_name))) + .await?; let message = response_to_message(&response)?; let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { tracing::debug!("Failed to get usage data"); Usage::default() }); - let model = get_model(&response); + let response_model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); - Ok((message, ProviderUsage::new(model, usage))) + Ok((message, ProviderUsage::new(response_model, usage))) } async fn stream( @@ -272,7 +306,10 @@ impl Provider for DatabricksProvider { messages: &[Message], tools: &[Tool], ) -> Result { - let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?; + let model_config = self.model.clone(); + + let mut payload = + create_request(&model_config, system, messages, tools, &self.image_format)?; payload .as_object_mut() .expect("payload should have model key") @@ -283,7 +320,7 @@ impl Provider for DatabricksProvider { .unwrap() .insert("stream".to_string(), Value::Bool(true)); - let path = self.get_endpoint_path(false); + let path = self.get_endpoint_path(&model_config.model_name, false); let response = self .with_retry(|| async { let resp = self.api_client.response_post(&path, &payload).await?; @@ -299,8 +336,8 @@ impl Provider for DatabricksProvider { .await?; let stream = response.bytes_stream().map_err(io::Error::other); - let model_config = self.model.clone(); + let model = self.model.clone(); Ok(Box::pin(try_stream! 
{ let stream_reader = StreamReader::new(stream); let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); @@ -309,7 +346,7 @@ impl Provider for DatabricksProvider { pin!(message_stream); while let Some(message) = message_stream.next().await { let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; - super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + super::utils::emit_debug_trace(&model, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); yield (message, usage); } })) @@ -408,7 +445,7 @@ impl EmbeddingCapable for DatabricksProvider { "input": texts, }); - let response = self.with_retry(|| self.post(request.clone())).await?; + let response = self.with_retry(|| self.post(request.clone(), None)).await?; let embeddings = response["data"] .as_array() diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index c06c8e954f0e..12f3aec90551 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -202,8 +202,9 @@ mod tests { self.model_config.clone() } - async fn complete( + async fn complete_with_model( &self, + _model_config: &ModelConfig, _system: &str, _messages: &[Message], _tools: &[Tool], diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 06f052b59c2a..b7472ca23464 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -1045,6 +1045,7 @@ mod tests { max_tokens: Some(1024), toolshim: false, toolshim_model: None, + fast_model: None, }; let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?; let obj = request.as_object().unwrap(); @@ -1076,6 +1077,7 @@ mod tests { max_tokens: Some(1024), toolshim: false, toolshim_model: None, + fast_model: None, }; let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?; let obj = request.as_object().unwrap(); @@ -1108,6 +1110,7 @@ mod tests { max_tokens: Some(1024), toolshim: false, toolshim_model: None, + fast_model: None, }; let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?; let obj = request.as_object().unwrap(); diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index e5bdfdb08dfc..3ff4b712a867 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -1077,6 +1077,7 @@ mod tests { max_tokens: Some(1024), toolshim: false, toolshim_model: None, + fast_model: None, }; let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?; let obj = request.as_object().unwrap(); @@ -1108,6 +1109,7 @@ mod tests { max_tokens: Some(1024), toolshim: false, toolshim_model: None, + fast_model: None, }; let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?; let obj = request.as_object().unwrap(); @@ -1140,6 +1142,7 @@ mod tests { max_tokens: Some(1024), toolshim: false, toolshim_model: None, + fast_model: None, }; let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?; let obj = request.as_object().unwrap(); diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 969d7146d7e2..609d77ab7eb6 100644 --- a/crates/goose/src/providers/gcpvertexai.rs 
+++ b/crates/goose/src/providers/gcpvertexai.rs
@@ -512,23 +512,24 @@
     /// * `messages` - Array of previous messages in the conversation
     /// * `tools` - Array of available tools for the model
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
         // Create request and context
-        let (request, context) = create_request(&self.model, system, messages, tools)?;
+        let (request, context) = create_request(model_config, system, messages, tools)?;
 
         // Send request and process response
         let response = self.post(&request, &context).await?;
         let usage = get_usage(&response, &context)?;
-        emit_debug_trace(&self.model, &request, &response, &usage);
+        emit_debug_trace(model_config, &request, &response, &usage);
 
         // Convert response to message
         let message = response_to_message(response, context)?;
diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs
index fcfc0f75c369..4b14270244fd 100644
--- a/crates/goose/src/providers/gemini_cli.rs
+++ b/crates/goose/src/providers/gemini_cli.rs
@@ -319,11 +319,12 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -350,7 +351,7 @@
             "usage": usage
         });
 
-        emit_debug_trace(&self.model, &payload, &response, &usage);
+        emit_debug_trace(model_config, &payload, &response, &usage);
 
         Ok((
             message,
diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs
index fc1e7fb640dd..720d8089b717 100644
--- a/crates/goose/src/providers/githubcopilot.rs
+++ b/crates/goose/src/providers/githubcopilot.rs
@@ -401,16 +401,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
+        let payload = create_request(model_config, system, messages, tools, &ImageFormat::OpenAi)?;
 
         // Make request with retry
         let response = self
@@ -426,9 +427,9 @@
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let model = get_model(&response);
-        emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 
     /// Fetch supported models from GitHub Copliot; returns Err on failure, Ok(None) if not present
diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs
index fa262f403c3a..4bf843088e34 100644
--- a/crates/goose/src/providers/google.rs
+++ b/crates/goose/src/providers/google.rs
@@ -14,6 +14,7 @@ use serde_json::Value;
 
 pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
 pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.5-flash";
+pub const GOOGLE_DEFAULT_FAST_MODEL: &str = "gemini-1.5-flash";
 pub const GOOGLE_KNOWN_MODELS: &[&str] = &[
     // Gemini 2.5 models (latest generation)
     "gemini-2.5-pro",
@@ -55,6 +56,8 @@ impl_provider_default!(GoogleProvider);
 
 impl GoogleProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let model = model.with_fast(GOOGLE_DEFAULT_FAST_MODEL.to_string());
+
         let config = crate::config::Config::global();
         let api_key: String = config.get_secret("GOOGLE_API_KEY")?;
         let host: String = config
@@ -72,8 +75,8 @@ impl GoogleProvider {
         Ok(Self { api_client, model })
     }
 
-    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
-        let path = format!("v1beta/models/{}:generateContent", self.model.model_name);
+    async fn post(&self, model_name: &str, payload: &Value) -> Result<Value, ProviderError> {
+        let path = format!("v1beta/models/{}:generateContent", model_name);
         let response = self.api_client.response_post(&path, payload).await?;
         handle_response_google_compat(response).await
     }
@@ -101,34 +104,35 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools)?;
+        let payload = create_request(model_config, system, messages, tools)?;
 
         // Make request
         let response = self
             .with_retry(|| async {
                 let payload_clone = payload.clone();
-                self.post(&payload_clone).await
+                self.post(&model_config.model_name, &payload_clone).await
             })
             .await?;
 
         // Parse response
         let message = response_to_message(unescape_json_values(&response))?;
         let usage = get_usage(&response)?;
-        let model = match response.get("modelVersion") {
+        let response_model = match response.get("modelVersion") {
             Some(model_version) => model_version.as_str().unwrap_or_default().to_string(),
-            None => self.model.model_name.clone(),
+            None => model_config.model_name.clone(),
         };
-        emit_debug_trace(&self.model, &payload, &response, &usage);
-        let provider_usage = ProviderUsage::new(model, usage);
+        emit_debug_trace(model_config, &payload, &response, &usage);
+        let provider_usage = ProviderUsage::new(response_model, usage);
         Ok((message, provider_usage))
     }
diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs
index acb51d1fc75d..e70a9d36aa58 100644
--- a/crates/goose/src/providers/groq.rs
+++ b/crates/goose/src/providers/groq.rs
@@ -77,17 +77,18 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
         let payload = create_request(
-            &self.model,
+            model_config,
             system,
             messages,
             tools,
@@ -101,9 +102,9 @@ impl Provider for GroqProvider {
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let model = get_model(&response);
-        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        super::utils::emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 
     /// Fetch supported models from Groq; returns Err on failure, Ok(None) if no models found
diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs
index 18564a2d1261..07638b2198bc 100644
--- a/crates/goose/src/providers/lead_worker.rs
+++ b/crates/goose/src/providers/lead_worker.rs
@@ -326,8 +326,9 @@
         self.lead_provider.get_model_config()
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -475,8 +476,9 @@ mod tests {
             self.model_config.clone()
         }
 
-        async fn complete(
+        async fn complete_with_model(
             &self,
+            _model_config: &ModelConfig,
             _system: &str,
             _messages: &[Message],
             _tools: &[Tool],
@@ -635,8 +637,9 @@ mod tests {
             self.model_config.clone()
         }
 
-        async fn complete(
+        async fn complete_with_model(
             &self,
+            _model_config: &ModelConfig,
             _system: &str,
             _messages: &[Message],
             _tools: &[Tool],
diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs
index 8911341b4d42..d08e6b5c286b 100644
--- a/crates/goose/src/providers/litellm.rs
+++ b/crates/goose/src/providers/litellm.rs
@@ -161,14 +161,15 @@
     }
 
     #[tracing::instrument(skip_all, name = "provider_complete")]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
         let mut payload = super::formats::openai::create_request(
-            &self.model,
+            model_config,
             system,
             messages,
             tools,
@@ -188,9 +189,9 @@
         let message = super::formats::openai::response_to_message(&response)?;
         let usage = super::formats::openai::get_usage(&response);
-        let model = get_model(&response);
-        emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 
     fn supports_embeddings(&self) -> bool {
diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index 84b6a5e5e184..0cf35cc408c8 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -165,11 +165,12 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -197,9 +198,9 @@
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let model = get_model(&response);
-        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        super::utils::emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 
     /// Generate a session name based on the conversation history
diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs
index d49dcfeffbdf..7a1398cfbfa0 100644
--- a/crates/goose/src/providers/openai.rs
+++ b/crates/goose/src/providers/openai.rs
@@ -29,6 +29,7 @@ use crate::providers::formats::openai::response_to_streaming_message;
 use rmcp::model::Tool;
 
 pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
+pub const OPEN_AI_DEFAULT_FAST_MODEL: &str = "gpt-4o-mini";
 pub const OPEN_AI_KNOWN_MODELS: &[(&str, usize)] = &[
     ("gpt-4o", 128_000),
     ("gpt-4o-mini", 128_000),
@@ -59,6 +60,8 @@ impl_provider_default!(OpenAiProvider);
 
 impl OpenAiProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string());
+
         let config = crate::config::Config::global();
         let api_key: String = config.get_secret("OPENAI_API_KEY")?;
         let host: String = config
@@ -193,16 +196,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
+        let payload = create_request(model_config, system, messages, tools, &ImageFormat::OpenAi)?;
 
         let json_response = self.post(&payload).await?;
diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs
index 00fa77bdea7a..49b9cea164e2 100644
--- a/crates/goose/src/providers/openrouter.rs
+++ b/crates/goose/src/providers/openrouter.rs
@@ -238,11 +238,12 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -264,9 +265,9 @@
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let model = get_model(&response);
-        emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 
     /// Fetch supported models from OpenRouter API (only models with tool support)
diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs
index 90a2498d0977..0c65e32d75cc 100644
--- a/crates/goose/src/providers/sagemaker_tgi.rs
+++ b/crates/goose/src/providers/sagemaker_tgi.rs
@@ -280,16 +280,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let model_name = &self.model.model_name;
+        let model_name = &model_config.model_name;
 
         let request_payload = self.create_tgi_request(system, messages).map_err(|e| {
             ProviderError::RequestFailed(format!("Failed to create request: {}", e))
diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs
index 8e8ea663b5ae..5b9344a29d9e 100644
--- a/crates/goose/src/providers/snowflake.rs
+++ b/crates/goose/src/providers/snowflake.rs
@@ -299,16 +299,17 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools)?;
+        let payload = create_request(model_config, system, messages, tools)?;
 
         let response = self
             .with_retry(|| async {
@@ -320,9 +321,9 @@
         // Parse response
         let message = response_to_message(&response)?;
         let usage = get_usage(&response)?;
-        let model = get_model(&response);
-        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
+        let response_model = get_model(&response);
+        super::utils::emit_debug_trace(model_config, &payload, &response, &usage);
 
-        Ok((message, ProviderUsage::new(model, usage)))
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 }
diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs
index eca2c87627b1..98c4f22739d9 100644
--- a/crates/goose/src/providers/testprovider.rs
+++ b/crates/goose/src/providers/testprovider.rs
@@ -112,8 +112,9 @@
         )
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -188,8 +189,9 @@ mod tests {
         )
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         _system: &str,
         _messages: &[Message],
         _tools: &[Tool],
diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs
index ce865f4cc519..2aacfc9e0b9c 100644
--- a/crates/goose/src/providers/tetrate.rs
+++ b/crates/goose/src/providers/tetrate.rs
@@ -1,4 +1,4 @@
-use anyhow::{Error, Result};
+use anyhow::Result;
 use async_trait::async_trait;
 use serde_json::Value;
 
@@ -113,23 +113,6 @@ impl TetrateProvider {
     }
 }
 
-fn create_request_based_on_model(
-    provider: &TetrateProvider,
-    system: &str,
-    messages: &[Message],
-    tools: &[Tool],
-) -> anyhow::Result<Value> {
-    let payload = create_request(
-        &provider.model,
-        system,
-        messages,
-        tools,
-        &super::utils::ImageFormat::OpenAi,
-    )?;
-
-    Ok(payload)
-}
-
 #[async_trait]
 impl Provider for TetrateProvider {
     fn metadata() -> ProviderMetadata {
@@ -157,17 +140,24 @@ impl Provider for TetrateProvider {
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        // Create the base payload
-        let payload = create_request_based_on_model(self, system, messages, tools)?;
+        // Create the base payload using the provided model_config
+        let payload = create_request(
+            model_config,
+            system,
+            messages,
+            tools,
+            &super::utils::ImageFormat::OpenAi,
+        )?;
 
         // Make request
         let response = self
@@ -184,7 +174,7 @@
             Usage::default()
         });
         let model = get_model(&response);
-        emit_debug_trace(&self.model, &payload, &response, &usage);
+        emit_debug_trace(model_config, &payload, &response, &usage);
         Ok((message, ProviderUsage::new(model, usage)))
     }
diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs
index 185587c6df6c..9c075bec9c06 100644
--- a/crates/goose/src/providers/venice.rs
+++ b/crates/goose/src/providers/venice.rs
@@ -246,12 +246,13 @@
     }
 
     #[tracing::instrument(
-        skip(_system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
-        _system: &str,
+        model_config: &ModelConfig,
+        system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
@@ -259,10 +260,10 @@
         let mut formatted_messages = Vec::new();
 
         // Add the system message if present
-        if !_system.is_empty() {
+        if !system.is_empty() {
             formatted_messages.push(json!({
                 "role": "system",
-                "content": _system
+                "content": system
             }));
         }
 
@@ -391,7 +392,7 @@
         // Build Venice-specific payload
         let mut payload = json!({
-            "model": strip_flags(&self.model.model_name),
+            "model": strip_flags(&model_config.model_name),
             "messages": formatted_messages,
             "stream": false,
             "temperature": 0.7,
@@ -470,7 +471,7 @@
             return Ok((
                 message,
                 ProviderUsage::new(
-                    strip_flags(&self.model.model_name).to_string(),
+                    strip_flags(&model_config.model_name).to_string(),
                     Usage::default(),
                 ),
             ));
diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs
index 7b2aed5f15c8..65ecccbd7f36 100644
--- a/crates/goose/src/providers/xai.rs
+++ b/crates/goose/src/providers/xai.rs
@@ -93,17 +93,18 @@
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
         let payload = create_request(
-            &self.model,
+            model_config,
             system,
             messages,
             tools,
@@ -117,8 +118,8 @@
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let model = get_model(&response);
-        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        super::utils::emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 }
diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs
index 0116e203fb30..c6dd933cb7ac 100644
--- a/crates/goose/src/scheduler.rs
+++ b/crates/goose/src/scheduler.rs
@@ -1390,8 +1390,9 @@ mod tests {
         self.model_config.clone()
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         _system: &str,
         _messages: &[Message],
         _tools: &[Tool],
diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs
index b4659332122b..9c7e84838c0a 100644
--- a/crates/goose/tests/agent.rs
+++ b/crates/goose/tests/agent.rs
@@ -592,6 +592,16 @@ mod final_output_tool_tests {
                 ProviderUsage::new("mock".to_string(), Usage::default()),
             ))
         }
+
+        async fn complete_with_model(
+            &self,
+            _model_config: &ModelConfig,
+            system: &str,
+            messages: &[Message],
+            tools: &[Tool],
+        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
+            self.complete(system, messages, tools).await
+        }
     }
 
     let agent = Agent::new();
@@ -713,6 +723,16 @@ mod final_output_tool_tests {
         ) -> Result<(Message, ProviderUsage), ProviderError> {
             Err(ProviderError::NotImplemented("Not implemented".to_string()))
         }
+
+        async fn complete_with_model(
+            &self,
+            _model_config: &ModelConfig,
+            system: &str,
+            messages: &[Message],
+            tools: &[Tool],
+        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
+            self.complete(system, messages, tools).await
+        }
     }
 
     let agent = Agent::new();
@@ -829,6 +849,16 @@ mod retry_tests {
                 ))
             }
         }
+
+        async fn complete_with_model(
+            &self,
+            _model_config: &ModelConfig,
+            system: &str,
+            messages: &[Message],
+            tools: &[Tool],
+        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
+            self.complete(system, messages, tools).await
+        }
     }
 
     #[tokio::test]
@@ -1002,6 +1032,16 @@
             Ok((message, usage))
         }
 
+        async fn complete_with_model(
+            &self,
+            _model_config: &ModelConfig,
+            system_prompt: &str,
+            messages: &[Message],
+            tools: &[Tool],
+        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
+            self.complete(system_prompt, messages, tools).await
+        }
+
         fn get_model_config(&self) -> ModelConfig {
             ModelConfig::new("mock-model").unwrap()
        }
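
The heart of this patch is the Provider trait refactor in base.rs: complete_with_model becomes the one method providers must implement, while complete and complete_fast are provided wrappers that only differ in which ModelConfig they pass down. Below is a minimal, self-contained sketch of that dispatch pattern, not the patch itself: the types are simplified stand-ins for the real goose definitions (no messages, tools, or usage accounting), EchoProvider is hypothetical, and the async-trait and tokio crates are assumed as dependencies.

// Sketch of the complete / complete_fast / complete_with_model dispatch.
#[derive(Clone, Debug)]
struct ModelConfig {
    model_name: String,
    fast_model: Option<String>,
}

impl ModelConfig {
    // Mirrors ModelConfig::use_fast_model in the diff: clone the config,
    // swapping in the fast model's name when one is configured.
    fn use_fast_model(&self) -> Self {
        match &self.fast_model {
            Some(fast) => ModelConfig {
                model_name: fast.clone(),
                ..self.clone()
            },
            None => self.clone(),
        }
    }
}

#[async_trait::async_trait]
trait Provider: Send + Sync {
    fn get_model_config(&self) -> ModelConfig;

    // The single method each provider must implement.
    async fn complete_with_model(&self, model_config: &ModelConfig, system: &str) -> String;

    // Provided wrapper: always routes through the configured model.
    async fn complete(&self, system: &str) -> String {
        let config = self.get_model_config();
        self.complete_with_model(&config, system).await
    }

    // Provided wrapper: routes through the fast model when configured,
    // otherwise falls back to the main model.
    async fn complete_fast(&self, system: &str) -> String {
        let config = self.get_model_config().use_fast_model();
        self.complete_with_model(&config, system).await
    }
}

// Hypothetical provider that just echoes which model it was asked to use.
struct EchoProvider {
    config: ModelConfig,
}

#[async_trait::async_trait]
impl Provider for EchoProvider {
    fn get_model_config(&self) -> ModelConfig {
        self.config.clone()
    }

    async fn complete_with_model(&self, model_config: &ModelConfig, system: &str) -> String {
        format!("[{}] {}", model_config.model_name, system)
    }
}

#[tokio::main]
async fn main() {
    let provider = EchoProvider {
        config: ModelConfig {
            model_name: "claude-sonnet-4-0".to_string(),
            fast_model: Some("claude-3-7-sonnet-latest".to_string()),
        },
    };
    // Routine work goes through the main model...
    assert!(provider.complete("hi").await.starts_with("[claude-sonnet-4-0]"));
    // ...while cheap side tasks take the fast path.
    assert!(provider.complete_fast("hi").await.starts_with("[claude-3-7-sonnet-latest]"));
}

This is why call sites such as summarize_messages and the session-naming helper can switch from .complete(...) to .complete_fast(...) with no other changes: for providers that never configure a fast model, use_fast_model degrades to a plain clone and behavior is unchanged.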