From 0ced45550048a7d254f0c155b1f64a11d79eda2e Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Mon, 27 Oct 2025 14:40:10 -0400 Subject: [PATCH 1/3] Enable runtime access to provider metadata (prep for provider/model config persistence) --- crates/goose/src/agents/reply_parts.rs | 4 ++++ crates/goose/src/providers/anthropic.rs | 9 +++++++++ crates/goose/src/providers/azure.rs | 7 +++++++ crates/goose/src/providers/base.rs | 5 ++++- crates/goose/src/providers/bedrock.rs | 8 ++++++++ crates/goose/src/providers/claude_code.rs | 8 ++++++++ crates/goose/src/providers/cursor_agent.rs | 8 ++++++++ crates/goose/src/providers/databricks.rs | 9 +++++++++ crates/goose/src/providers/gcpvertexai.rs | 8 ++++++++ crates/goose/src/providers/gemini_cli.rs | 8 ++++++++ crates/goose/src/providers/githubcopilot.rs | 8 ++++++++ crates/goose/src/providers/google.rs | 13 ++++++++++++- crates/goose/src/providers/lead_worker.rs | 13 +++++++++++++ crates/goose/src/providers/litellm.rs | 8 ++++++++ crates/goose/src/providers/ollama.rs | 9 +++++++++ crates/goose/src/providers/openai.rs | 9 +++++++++ crates/goose/src/providers/openrouter.rs | 13 ++++++++++++- crates/goose/src/providers/sagemaker_tgi.rs | 8 ++++++++ crates/goose/src/providers/snowflake.rs | 8 ++++++++ crates/goose/src/providers/testprovider.rs | 11 +++++++++++ crates/goose/src/providers/tetrate.rs | 8 ++++++++ crates/goose/src/providers/venice.rs | 8 ++++++++ crates/goose/src/providers/xai.rs | 13 ++++++++++++- crates/goose/src/scheduler.rs | 4 ++++ crates/goose/tests/agent.rs | 16 ++++++++++++++++ 25 files changed, 219 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 7d54e60fb045..7d3cd1f74c26 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -309,6 +309,10 @@ mod tests { crate::providers::base::ProviderMetadata::empty() } + fn get_metadata(&self) -> std::sync::Arc { + std::sync::Arc::new(crate::providers::base::ProviderMetadata::empty()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 0558e3fe985b..0f0c3f8e19b7 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -5,6 +5,7 @@ use futures::TryStreamExt; use reqwest::StatusCode; use serde_json::Value; use std::io; +use std::sync::Arc; use tokio::pin; use tokio_util::io::StreamReader; @@ -43,6 +44,8 @@ pub struct AnthropicProvider { api_client: ApiClient, model: ModelConfig, supports_streaming: bool, + #[serde(skip)] + metadata: Arc, } impl AnthropicProvider { @@ -67,6 +70,7 @@ impl AnthropicProvider { api_client, model, supports_streaming: true, + metadata: Arc::new(Self::metadata()), }) } @@ -91,6 +95,7 @@ impl AnthropicProvider { api_client, model, supports_streaming: config.supports_streaming.unwrap_or(true), + metadata: Arc::new(Self::metadata()), }) } @@ -176,6 +181,10 @@ impl Provider for AnthropicProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index cd7f51a42af6..2d49646afd20 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -2,6 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use serde::Serialize; use serde_json::Value; +use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod, AuthProvider}; use super::azureauth::{AuthError, AzureAuth}; @@ -27,6 +28,7 @@ pub struct AzureProvider { deployment_name: String, api_version: String, model: ModelConfig, + metadata: Arc, } impl Serialize for AzureProvider { @@ -94,6 +96,7 @@ impl AzureProvider { deployment_name, api_version, model, + metadata: Arc::new(Self::metadata()), }) } @@ -128,6 +131,10 @@ impl Provider for AzureProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index f8dfadee3757..bf0c3415938e 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -14,7 +14,7 @@ use utoipa::ToSchema; use once_cell::sync::Lazy; use std::ops::{Add, AddAssign}; use std::pin::Pin; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; /// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias pub static CURRENT_MODEL: Lazy>> = Lazy::new(|| Mutex::new(None)); @@ -325,6 +325,9 @@ pub trait Provider: Send + Sync { where Self: Sized; + /// Get the metadata for this provider instance + fn get_metadata(&self) -> Arc; + // Internal implementation of complete, used by complete_fast and complete // Providers should override this to implement their actual completion logic async fn complete_with_model( diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index b5f720692a0a..7a17ce95c23b 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; @@ -42,6 +43,8 @@ pub struct BedrockProvider { model: ModelConfig, #[serde(skip)] retry_config: RetryConfig, + #[serde(skip)] + metadata: Arc, } impl BedrockProvider { @@ -78,6 +81,7 @@ impl BedrockProvider { client, model, retry_config, + metadata: Arc::new(Self::metadata()), }) } @@ -184,6 +188,10 @@ impl Provider for BedrockProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn retry_config(&self) -> RetryConfig { self.retry_config.clone() } diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index d27be5ce273f..8dd71f3334ca 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -4,6 +4,7 @@ use rmcp::model::Role; use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; @@ -24,6 +25,8 @@ pub const CLAUDE_CODE_DOC_URL: &str = "https://claude.ai/cli"; pub struct ClaudeCodeProvider { command: String, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl ClaudeCodeProvider { @@ -42,6 +45,7 @@ impl ClaudeCodeProvider { Ok(Self { command: resolved_command, model, + metadata: Arc::new(Self::metadata()), }) } @@ -463,6 +467,10 @@ impl Provider for ClaudeCodeProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { // Return the model config with appropriate context limit for Claude models self.model.clone() diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index b5f4e0b98552..6f467dcdc0c5 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -4,6 +4,7 @@ use rmcp::model::Role; use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; @@ -23,6 +24,8 @@ pub const CURSOR_AGENT_DOC_URL: &str = "https://docs.cursor.com/en/cli/overview" pub struct CursorAgentProvider { command: String, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl CursorAgentProvider { @@ -41,6 +44,7 @@ impl CursorAgentProvider { Ok(Self { command: resolved_command, model, + metadata: Arc::new(Self::metadata()), }) } @@ -395,6 +399,10 @@ impl Provider for CursorAgentProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { // Return the model config with appropriate context limit for Cursor models self.model.clone() diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 44dc59971dd7..9ed485f363fd 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -5,6 +5,7 @@ use futures::TryStreamExt; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::io; +use std::sync::Arc; use std::time::Duration; use tokio::pin; use tokio_util::io::StreamReader; @@ -106,6 +107,8 @@ pub struct DatabricksProvider { image_format: ImageFormat, #[serde(skip)] retry_config: RetryConfig, + #[serde(skip)] + metadata: Arc, } impl DatabricksProvider { @@ -146,6 +149,7 @@ impl DatabricksProvider { model: model.clone(), image_format: ImageFormat::OpenAi, retry_config, + metadata: Arc::new(Self::metadata()), }; // Check if the default fast model exists in the workspace @@ -222,6 +226,7 @@ impl DatabricksProvider { model, image_format: ImageFormat::OpenAi, retry_config: RetryConfig::default(), + metadata: Arc::new(Self::metadata()), }) } @@ -260,6 +265,10 @@ impl Provider for DatabricksProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn retry_config(&self) -> RetryConfig { self.retry_config.clone() } diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index e2ac26a31af3..a56440662621 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use std::time::Duration; use anyhow::Result; @@ -76,6 +77,8 @@ pub struct GcpVertexAIProvider { /// Retry configuration for handling rate limit errors #[serde(skip)] retry_config: RetryConfig, + #[serde(skip)] + metadata: Arc, } impl GcpVertexAIProvider { @@ -109,6 +112,7 @@ impl GcpVertexAIProvider { location, model, retry_config, + metadata: Arc::new(Self::metadata()), }) } @@ -494,6 +498,10 @@ impl Provider for GcpVertexAIProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + /// Completes a model interaction by sending a request and processing the response. /// /// # Arguments diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 417774a70002..307dfb2d6b85 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use serde_json::json; use std::path::PathBuf; use std::process::Stdio; +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; @@ -24,6 +25,8 @@ pub const GEMINI_CLI_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs"; pub struct GeminiCliProvider { command: String, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl GeminiCliProvider { @@ -42,6 +45,7 @@ impl GeminiCliProvider { Ok(Self { command: resolved_command, model, + metadata: Arc::new(Self::metadata()), }) } @@ -311,6 +315,10 @@ impl Provider for GeminiCliProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { // Return the model config with appropriate context limit for Gemini models self.model.clone() diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 4a822ee7abbe..2418c82ec002 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -9,6 +9,7 @@ use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; use std::time::Duration; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -113,6 +114,8 @@ pub struct GithubCopilotProvider { #[serde(skip)] mu: tokio::sync::Mutex>>, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl GithubCopilotProvider { @@ -127,6 +130,7 @@ impl GithubCopilotProvider { cache, mu, model, + metadata: Arc::new(Self::metadata()), }) } @@ -392,6 +396,10 @@ impl Provider for GithubCopilotProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index d6d7b30dd2ac..54c4305ab1b1 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -3,6 +3,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{handle_response_google_compat, unescape_json_values, RequestLog}; use crate::conversation::message::Message; +use std::sync::Arc; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; @@ -39,6 +40,8 @@ pub struct GoogleProvider { #[serde(skip)] api_client: ApiClient, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl GoogleProvider { @@ -59,7 +62,11 @@ impl GoogleProvider { let api_client = ApiClient::new(host, auth)?.with_header("Content-Type", "application/json")?; - Ok(Self { api_client, model }) + Ok(Self { + api_client, + model, + metadata: Arc::new(Self::metadata()), + }) } async fn post(&self, model_name: &str, payload: &Value) -> Result { @@ -86,6 +93,10 @@ impl Provider for GoogleProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index bb2d457b5a4a..4930d7ff898a 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -320,6 +320,11 @@ impl Provider for LeadWorkerProvider { ) } + fn get_metadata(&self) -> Arc { + // Return the lead provider's metadata as the default + self.lead_provider.get_metadata() + } + fn get_model_config(&self) -> ModelConfig { // Return the lead provider's model config as the default // In practice, this might need to be more sophisticated @@ -472,6 +477,10 @@ mod tests { ProviderMetadata::empty() } + fn get_metadata(&self) -> Arc { + Arc::new(ProviderMetadata::empty()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } @@ -634,6 +643,10 @@ mod tests { ProviderMetadata::empty() } + fn get_metadata(&self) -> Arc { + Arc::new(ProviderMetadata::empty()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index a0839724e8fc..f98bbc2d4b36 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -2,6 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use serde_json::{json, Value}; use std::collections::HashMap; +use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; @@ -23,6 +24,8 @@ pub struct LiteLLMProvider { api_client: ApiClient, base_path: String, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl LiteLLMProvider { @@ -67,6 +70,7 @@ impl LiteLLMProvider { api_client, base_path, model, + metadata: Arc::new(Self::metadata()), }) } @@ -154,6 +158,10 @@ impl Provider for LiteLLMProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index b355b6f244d6..7c25fed67c6b 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -22,6 +22,7 @@ use regex::Regex; use rmcp::model::Tool; use serde_json::{json, Value}; use std::io; +use std::sync::Arc; use std::time::Duration; use tokio::pin; use tokio_stream::StreamExt; @@ -47,6 +48,8 @@ pub struct OllamaProvider { api_client: ApiClient, model: ModelConfig, supports_streaming: bool, + #[serde(skip)] + metadata: Arc, } impl OllamaProvider { @@ -92,6 +95,7 @@ impl OllamaProvider { api_client, model, supports_streaming: true, + metadata: Arc::new(Self::metadata()), }) } @@ -131,6 +135,7 @@ impl OllamaProvider { api_client, model, supports_streaming: config.supports_streaming.unwrap_or(true), + metadata: Arc::new(Self::metadata()), }) } @@ -176,6 +181,10 @@ impl Provider for OllamaProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index c00d8e6a5ed5..c64b24a85e85 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -6,6 +6,7 @@ use reqwest::StatusCode; use serde_json::{json, Value}; use std::collections::HashMap; use std::io; +use std::sync::Arc; use tokio::pin; use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; @@ -54,6 +55,8 @@ pub struct OpenAiProvider { model: ModelConfig, custom_headers: Option>, supports_streaming: bool, + #[serde(skip)] + metadata: Arc, } impl OpenAiProvider { @@ -107,6 +110,7 @@ impl OpenAiProvider { model, custom_headers, supports_streaming: true, + metadata: Arc::new(Self::metadata()), }) } @@ -163,6 +167,7 @@ impl OpenAiProvider { model, custom_headers: config.headers, supports_streaming: config.supports_streaming.unwrap_or(true), + metadata: Arc::new(Self::metadata()), }) } @@ -201,6 +206,10 @@ impl Provider for OpenAiProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index f9acb72152f6..e5112c847dee 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -1,6 +1,7 @@ use anyhow::{Error, Result}; use async_trait::async_trait; use serde_json::{json, Value}; +use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -40,6 +41,8 @@ pub struct OpenRouterProvider { #[serde(skip)] api_client: ApiClient, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl OpenRouterProvider { @@ -57,7 +60,11 @@ impl OpenRouterProvider { .with_header("HTTP-Referer", "https://block.github.io/goose")? .with_header("X-Title", "goose")?; - Ok(Self { api_client, model }) + Ok(Self { + api_client, + model, + metadata: Arc::new(Self::metadata()), + }) } async fn post(&self, payload: &Value) -> Result { @@ -242,6 +249,10 @@ impl Provider for OpenRouterProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 9c804147cbd1..c86aead47a58 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use std::time::Duration; use anyhow::Result; @@ -30,6 +31,8 @@ pub struct SageMakerTgiProvider { sagemaker_client: SageMakerClient, endpoint_name: String, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl SageMakerTgiProvider { @@ -79,6 +82,7 @@ impl SageMakerTgiProvider { sagemaker_client, endpoint_name, model, + metadata: Arc::new(Self::metadata()), }) } @@ -272,6 +276,10 @@ impl Provider for SageMakerTgiProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 7138625aaaea..0e8c2c57ab53 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -2,6 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; @@ -48,6 +49,8 @@ pub struct SnowflakeProvider { api_client: ApiClient, model: ModelConfig, image_format: ImageFormat, + #[serde(skip)] + metadata: Arc, } impl SnowflakeProvider { @@ -101,6 +104,7 @@ impl SnowflakeProvider { api_client, model, image_format: ImageFormat::OpenAi, + metadata: Arc::new(Self::metadata()), }) } @@ -302,6 +306,10 @@ impl Provider for SnowflakeProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index e15cd4a0fc9e..52bca052fc19 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -36,6 +36,7 @@ pub struct TestProvider { inner: Option>, records: Arc>>, file_path: String, + metadata: Arc, } impl TestProvider { @@ -44,6 +45,7 @@ impl TestProvider { inner: Some(inner), records: Arc::new(Mutex::new(HashMap::new())), file_path: file_path.into(), + metadata: Arc::new(Self::metadata()), } } @@ -55,6 +57,7 @@ impl TestProvider { inner: None, records: Arc::new(Mutex::new(records)), file_path, + metadata: Arc::new(Self::metadata()), }) } @@ -112,6 +115,10 @@ impl Provider for TestProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + async fn complete_with_model( &self, _model_config: &ModelConfig, @@ -189,6 +196,10 @@ mod tests { ) } + fn get_metadata(&self) -> Arc { + Arc::new(Self::metadata()) + } + async fn complete_with_model( &self, _model_config: &ModelConfig, diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index e771d4916019..d686ae13816e 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use futures::TryStreamExt; use serde_json::{json, Value}; use std::io; +use std::sync::Arc; use tokio::pin; use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; @@ -46,6 +47,8 @@ pub struct TetrateProvider { api_client: ApiClient, model: ModelConfig, supports_streaming: bool, + #[serde(skip)] + metadata: Arc, } impl TetrateProvider { @@ -66,6 +69,7 @@ impl TetrateProvider { api_client, model, supports_streaming: true, + metadata: Arc::new(Self::metadata()), }) } @@ -150,6 +154,10 @@ impl Provider for TetrateProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 701af251be4e..bcef89614402 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use chrono::Utc; use serde::Serialize; use serde_json::{json, Value}; +use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -78,6 +79,8 @@ pub struct VeniceProvider { base_path: String, models_path: String, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl VeniceProvider { @@ -105,6 +108,7 @@ impl VeniceProvider { base_path, models_path, model, + metadata: Arc::new(Self::metadata()), }; Ok(instance) @@ -210,6 +214,10 @@ impl Provider for VeniceProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index ab75680acd64..300d65859fee 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -3,6 +3,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat, RequestLog}; use crate::conversation::message::Message; +use std::sync::Arc; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -42,6 +43,8 @@ pub struct XaiProvider { #[serde(skip)] api_client: ApiClient, model: ModelConfig, + #[serde(skip)] + metadata: Arc, } impl XaiProvider { @@ -55,7 +58,11 @@ impl XaiProvider { let auth = AuthMethod::BearerToken(api_key); let api_client = ApiClient::new(host, auth)?; - Ok(Self { api_client, model }) + Ok(Self { + api_client, + model, + metadata: Arc::new(Self::metadata()), + }) } async fn post(&self, payload: Value) -> Result { @@ -87,6 +94,10 @@ impl Provider for XaiProvider { ) } + fn get_metadata(&self) -> Arc { + Arc::clone(&self.metadata) + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index e62ed05e49f5..d8e1f557141f 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1361,6 +1361,10 @@ mod tests { ) } + fn get_metadata(&self) -> std::sync::Arc { + std::sync::Arc::new(Self::metadata()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index e97f0bfb53c8..6fb7864c0320 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -557,6 +557,10 @@ mod final_output_tool_tests { goose::providers::base::ProviderMetadata::empty() } + fn get_metadata(&self) -> std::sync::Arc { + std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } @@ -672,6 +676,10 @@ mod final_output_tool_tests { goose::providers::base::ProviderMetadata::empty() } + fn get_metadata(&self) -> std::sync::Arc { + std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } @@ -858,6 +866,10 @@ mod retry_tests { goose::providers::base::ProviderMetadata::empty() } + fn get_metadata(&self) -> std::sync::Arc { + std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + } + fn get_model_config(&self) -> ModelConfig { self.model_config.clone() } @@ -1080,6 +1092,10 @@ mod max_turns_tests { config_keys: vec![], } } + + fn get_metadata(&self) -> std::sync::Arc { + std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + } } #[tokio::test] From 239c4a43c6570a78a930d5b68e63da6d6f15134e Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Mon, 27 Oct 2025 17:16:01 -0400 Subject: [PATCH 2/3] Only store provider name instead of full metadata to support custom providers --- crates/goose/src/agents/reply_parts.rs | 4 +-- .../goose/src/config/declarative_providers.rs | 27 ++++++++++++++++--- crates/goose/src/providers/anthropic.rs | 13 +++++---- crates/goose/src/providers/azure.rs | 9 +++---- crates/goose/src/providers/base.rs | 6 ++--- crates/goose/src/providers/bedrock.rs | 9 +++---- crates/goose/src/providers/claude_code.rs | 9 +++---- crates/goose/src/providers/cursor_agent.rs | 9 +++---- crates/goose/src/providers/databricks.rs | 11 ++++---- crates/goose/src/providers/gcpvertexai.rs | 9 +++---- crates/goose/src/providers/gemini_cli.rs | 9 +++---- crates/goose/src/providers/githubcopilot.rs | 9 +++---- crates/goose/src/providers/google.rs | 9 +++---- crates/goose/src/providers/lead_worker.rs | 14 +++++----- crates/goose/src/providers/litellm.rs | 9 +++---- crates/goose/src/providers/ollama.rs | 13 +++++---- crates/goose/src/providers/openai.rs | 13 +++++---- crates/goose/src/providers/openrouter.rs | 9 +++---- crates/goose/src/providers/sagemaker_tgi.rs | 9 +++---- crates/goose/src/providers/snowflake.rs | 9 +++---- crates/goose/src/providers/testprovider.rs | 14 +++++----- crates/goose/src/providers/tetrate.rs | 9 +++---- crates/goose/src/providers/venice.rs | 9 +++---- crates/goose/src/providers/xai.rs | 9 +++---- crates/goose/src/scheduler.rs | 4 +-- crates/goose/tests/agent.rs | 16 +++++------ 26 files changed, 136 insertions(+), 134 deletions(-) diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 7d3cd1f74c26..1d8fdaea06b4 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -309,8 +309,8 @@ mod tests { crate::providers::base::ProviderMetadata::empty() } - fn get_metadata(&self) -> std::sync::Arc { - std::sync::Arc::new(crate::providers::base::ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock" } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/config/declarative_providers.rs b/crates/goose/src/config/declarative_providers.rs index 10df3ff1021b..5255a00239b6 100644 --- a/crates/goose/src/config/declarative_providers.rs +++ b/crates/goose/src/config/declarative_providers.rs @@ -293,24 +293,45 @@ pub fn register_declarative_provider( match config.engine { ProviderEngine::OpenAI => { + let provider_name = config.name.clone(); registry.register_with_name::( &config, provider_type, - move |model| OpenAiProvider::from_custom_config(model, config_clone.clone()), + move |model| { + OpenAiProvider::from_custom_config( + model, + config_clone.clone(), + provider_name.clone(), + ) + }, ); } ProviderEngine::Ollama => { + let provider_name = config.name.clone(); registry.register_with_name::( &config, provider_type, - move |model| OllamaProvider::from_custom_config(model, config_clone.clone()), + move |model| { + OllamaProvider::from_custom_config( + model, + config_clone.clone(), + provider_name.clone(), + ) + }, ); } ProviderEngine::Anthropic => { + let provider_name = config.name.clone(); registry.register_with_name::( &config, provider_type, - move |model| AnthropicProvider::from_custom_config(model, config_clone.clone()), + move |model| { + AnthropicProvider::from_custom_config( + model, + config_clone.clone(), + provider_name.clone(), + ) + }, ); } } diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 0f0c3f8e19b7..5a47c128d97c 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -5,7 +5,6 @@ use futures::TryStreamExt; use reqwest::StatusCode; use serde_json::Value; use std::io; -use std::sync::Arc; use tokio::pin; use tokio_util::io::StreamReader; @@ -44,8 +43,7 @@ pub struct AnthropicProvider { api_client: ApiClient, model: ModelConfig, supports_streaming: bool, - #[serde(skip)] - metadata: Arc, + name: String, } impl AnthropicProvider { @@ -70,13 +68,14 @@ impl AnthropicProvider { api_client, model, supports_streaming: true, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, + provider_name: String, ) -> Result { let global_config = crate::config::Config::global(); let api_key: String = global_config @@ -95,7 +94,7 @@ impl AnthropicProvider { api_client, model, supports_streaming: config.supports_streaming.unwrap_or(true), - metadata: Arc::new(Self::metadata()), + name: provider_name, }) } @@ -181,8 +180,8 @@ impl Provider for AnthropicProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 2d49646afd20..d519a82c7345 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -2,7 +2,6 @@ use anyhow::Result; use async_trait::async_trait; use serde::Serialize; use serde_json::Value; -use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod, AuthProvider}; use super::azureauth::{AuthError, AzureAuth}; @@ -28,7 +27,7 @@ pub struct AzureProvider { deployment_name: String, api_version: String, model: ModelConfig, - metadata: Arc, + name: String, } impl Serialize for AzureProvider { @@ -96,7 +95,7 @@ impl AzureProvider { deployment_name, api_version, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -131,8 +130,8 @@ impl Provider for AzureProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index bf0c3415938e..0067af9fc8f8 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -14,7 +14,7 @@ use utoipa::ToSchema; use once_cell::sync::Lazy; use std::ops::{Add, AddAssign}; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; /// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias pub static CURRENT_MODEL: Lazy>> = Lazy::new(|| Mutex::new(None)); @@ -325,8 +325,8 @@ pub trait Provider: Send + Sync { where Self: Sized; - /// Get the metadata for this provider instance - fn get_metadata(&self) -> Arc; + /// Get the name of this provider instance + fn get_name(&self) -> &str; // Internal implementation of complete, used by complete_fast and complete // Providers should override this to implement their actual completion logic diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 7a17ce95c23b..ff5300484316 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::Arc; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; @@ -44,7 +43,7 @@ pub struct BedrockProvider { #[serde(skip)] retry_config: RetryConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl BedrockProvider { @@ -81,7 +80,7 @@ impl BedrockProvider { client, model, retry_config, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -188,8 +187,8 @@ impl Provider for BedrockProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn retry_config(&self) -> RetryConfig { diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 8dd71f3334ca..9683c2ee03d4 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -4,7 +4,6 @@ use rmcp::model::Role; use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; -use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; @@ -26,7 +25,7 @@ pub struct ClaudeCodeProvider { command: String, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl ClaudeCodeProvider { @@ -45,7 +44,7 @@ impl ClaudeCodeProvider { Ok(Self { command: resolved_command, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -467,8 +466,8 @@ impl Provider for ClaudeCodeProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index 6f467dcdc0c5..31b185f3ddae 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -4,7 +4,6 @@ use rmcp::model::Role; use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; -use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; @@ -25,7 +24,7 @@ pub struct CursorAgentProvider { command: String, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl CursorAgentProvider { @@ -44,7 +43,7 @@ impl CursorAgentProvider { Ok(Self { command: resolved_command, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -399,8 +398,8 @@ impl Provider for CursorAgentProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 9ed485f363fd..b36b07f1484e 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -5,7 +5,6 @@ use futures::TryStreamExt; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::io; -use std::sync::Arc; use std::time::Duration; use tokio::pin; use tokio_util::io::StreamReader; @@ -108,7 +107,7 @@ pub struct DatabricksProvider { #[serde(skip)] retry_config: RetryConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl DatabricksProvider { @@ -149,7 +148,7 @@ impl DatabricksProvider { model: model.clone(), image_format: ImageFormat::OpenAi, retry_config, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }; // Check if the default fast model exists in the workspace @@ -226,7 +225,7 @@ impl DatabricksProvider { model, image_format: ImageFormat::OpenAi, retry_config: RetryConfig::default(), - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -265,8 +264,8 @@ impl Provider for DatabricksProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn retry_config(&self) -> RetryConfig { diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index a56440662621..62cf82289a86 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -1,4 +1,3 @@ -use std::sync::Arc; use std::time::Duration; use anyhow::Result; @@ -78,7 +77,7 @@ pub struct GcpVertexAIProvider { #[serde(skip)] retry_config: RetryConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl GcpVertexAIProvider { @@ -112,7 +111,7 @@ impl GcpVertexAIProvider { location, model, retry_config, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -498,8 +497,8 @@ impl Provider for GcpVertexAIProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } /// Completes a model interaction by sending a request and processing the response. diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 307dfb2d6b85..e13610e7dd09 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -3,7 +3,6 @@ use async_trait::async_trait; use serde_json::json; use std::path::PathBuf; use std::process::Stdio; -use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; @@ -26,7 +25,7 @@ pub struct GeminiCliProvider { command: String, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl GeminiCliProvider { @@ -45,7 +44,7 @@ impl GeminiCliProvider { Ok(Self { command: resolved_command, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -315,8 +314,8 @@ impl Provider for GeminiCliProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 2418c82ec002..08a7074c48ba 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -9,7 +9,6 @@ use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; use std::path::PathBuf; -use std::sync::Arc; use std::time::Duration; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -115,7 +114,7 @@ pub struct GithubCopilotProvider { mu: tokio::sync::Mutex>>, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl GithubCopilotProvider { @@ -130,7 +129,7 @@ impl GithubCopilotProvider { cache, mu, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -396,8 +395,8 @@ impl Provider for GithubCopilotProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 54c4305ab1b1..6f73268aa5e2 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -3,7 +3,6 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{handle_response_google_compat, unescape_json_values, RequestLog}; use crate::conversation::message::Message; -use std::sync::Arc; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; @@ -41,7 +40,7 @@ pub struct GoogleProvider { api_client: ApiClient, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl GoogleProvider { @@ -65,7 +64,7 @@ impl GoogleProvider { Ok(Self { api_client, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -93,8 +92,8 @@ impl Provider for GoogleProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 4930d7ff898a..1dc8fe7a8b7c 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -320,9 +320,9 @@ impl Provider for LeadWorkerProvider { ) } - fn get_metadata(&self) -> Arc { - // Return the lead provider's metadata as the default - self.lead_provider.get_metadata() + fn get_name(&self) -> &str { + // Return the lead provider's name as the default + self.lead_provider.get_name() } fn get_model_config(&self) -> ModelConfig { @@ -477,8 +477,8 @@ mod tests { ProviderMetadata::empty() } - fn get_metadata(&self) -> Arc { - Arc::new(ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock-lead" } fn get_model_config(&self) -> ModelConfig { @@ -643,8 +643,8 @@ mod tests { ProviderMetadata::empty() } - fn get_metadata(&self) -> Arc { - Arc::new(ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock-lead" } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index f98bbc2d4b36..65bebce57504 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -2,7 +2,6 @@ use anyhow::Result; use async_trait::async_trait; use serde_json::{json, Value}; use std::collections::HashMap; -use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; @@ -25,7 +24,7 @@ pub struct LiteLLMProvider { base_path: String, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl LiteLLMProvider { @@ -70,7 +69,7 @@ impl LiteLLMProvider { api_client, base_path, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -158,8 +157,8 @@ impl Provider for LiteLLMProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 7c25fed67c6b..98b1d9f826e5 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -22,7 +22,6 @@ use regex::Regex; use rmcp::model::Tool; use serde_json::{json, Value}; use std::io; -use std::sync::Arc; use std::time::Duration; use tokio::pin; use tokio_stream::StreamExt; @@ -48,8 +47,7 @@ pub struct OllamaProvider { api_client: ApiClient, model: ModelConfig, supports_streaming: bool, - #[serde(skip)] - metadata: Arc, + name: String, } impl OllamaProvider { @@ -95,13 +93,14 @@ impl OllamaProvider { api_client, model, supports_streaming: true, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, + provider_name: String, ) -> Result { let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT)); @@ -135,7 +134,7 @@ impl OllamaProvider { api_client, model, supports_streaming: config.supports_streaming.unwrap_or(true), - metadata: Arc::new(Self::metadata()), + name: provider_name, }) } @@ -181,8 +180,8 @@ impl Provider for OllamaProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index c64b24a85e85..e347f6db091b 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -6,7 +6,6 @@ use reqwest::StatusCode; use serde_json::{json, Value}; use std::collections::HashMap; use std::io; -use std::sync::Arc; use tokio::pin; use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; @@ -55,8 +54,7 @@ pub struct OpenAiProvider { model: ModelConfig, custom_headers: Option>, supports_streaming: bool, - #[serde(skip)] - metadata: Arc, + name: String, } impl OpenAiProvider { @@ -110,13 +108,14 @@ impl OpenAiProvider { model, custom_headers, supports_streaming: true, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, + provider_name: String, ) -> Result { let global_config = crate::config::Config::global(); let api_key: String = global_config @@ -167,7 +166,7 @@ impl OpenAiProvider { model, custom_headers: config.headers, supports_streaming: config.supports_streaming.unwrap_or(true), - metadata: Arc::new(Self::metadata()), + name: provider_name, }) } @@ -206,8 +205,8 @@ impl Provider for OpenAiProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index e5112c847dee..e87fb3741e44 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -1,7 +1,6 @@ use anyhow::{Error, Result}; use async_trait::async_trait; use serde_json::{json, Value}; -use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -42,7 +41,7 @@ pub struct OpenRouterProvider { api_client: ApiClient, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl OpenRouterProvider { @@ -63,7 +62,7 @@ impl OpenRouterProvider { Ok(Self { api_client, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -249,8 +248,8 @@ impl Provider for OpenRouterProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index c86aead47a58..5861b09cdd59 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::Arc; use std::time::Duration; use anyhow::Result; @@ -32,7 +31,7 @@ pub struct SageMakerTgiProvider { endpoint_name: String, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl SageMakerTgiProvider { @@ -82,7 +81,7 @@ impl SageMakerTgiProvider { sagemaker_client, endpoint_name, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -276,8 +275,8 @@ impl Provider for SageMakerTgiProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 0e8c2c57ab53..7176d59d2055 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -2,7 +2,6 @@ use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; @@ -50,7 +49,7 @@ pub struct SnowflakeProvider { model: ModelConfig, image_format: ImageFormat, #[serde(skip)] - metadata: Arc, + name: String, } impl SnowflakeProvider { @@ -104,7 +103,7 @@ impl SnowflakeProvider { api_client, model, image_format: ImageFormat::OpenAi, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -306,8 +305,8 @@ impl Provider for SnowflakeProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index 52bca052fc19..c9e455bd69d3 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -36,7 +36,7 @@ pub struct TestProvider { inner: Option>, records: Arc>>, file_path: String, - metadata: Arc, + name: String, } impl TestProvider { @@ -45,7 +45,7 @@ impl TestProvider { inner: Some(inner), records: Arc::new(Mutex::new(HashMap::new())), file_path: file_path.into(), - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, } } @@ -57,7 +57,7 @@ impl TestProvider { inner: None, records: Arc::new(Mutex::new(records)), file_path, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -115,8 +115,8 @@ impl Provider for TestProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } async fn complete_with_model( @@ -196,8 +196,8 @@ mod tests { ) } - fn get_metadata(&self) -> Arc { - Arc::new(Self::metadata()) + fn get_name(&self) -> &str { + "mock-testprovider" } async fn complete_with_model( diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index d686ae13816e..235f7b3198a7 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -4,7 +4,6 @@ use async_trait::async_trait; use futures::TryStreamExt; use serde_json::{json, Value}; use std::io; -use std::sync::Arc; use tokio::pin; use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; @@ -48,7 +47,7 @@ pub struct TetrateProvider { model: ModelConfig, supports_streaming: bool, #[serde(skip)] - metadata: Arc, + name: String, } impl TetrateProvider { @@ -69,7 +68,7 @@ impl TetrateProvider { api_client, model, supports_streaming: true, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -154,8 +153,8 @@ impl Provider for TetrateProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index bcef89614402..4a699222d4a5 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -3,7 +3,6 @@ use async_trait::async_trait; use chrono::Utc; use serde::Serialize; use serde_json::{json, Value}; -use std::sync::Arc; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -80,7 +79,7 @@ pub struct VeniceProvider { models_path: String, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl VeniceProvider { @@ -108,7 +107,7 @@ impl VeniceProvider { base_path, models_path, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }; Ok(instance) @@ -214,8 +213,8 @@ impl Provider for VeniceProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 300d65859fee..0078d9894e97 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -3,7 +3,6 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat, RequestLog}; use crate::conversation::message::Message; -use std::sync::Arc; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -44,7 +43,7 @@ pub struct XaiProvider { api_client: ApiClient, model: ModelConfig, #[serde(skip)] - metadata: Arc, + name: String, } impl XaiProvider { @@ -61,7 +60,7 @@ impl XaiProvider { Ok(Self { api_client, model, - metadata: Arc::new(Self::metadata()), + name: Self::metadata().name, }) } @@ -94,8 +93,8 @@ impl Provider for XaiProvider { ) } - fn get_metadata(&self) -> Arc { - Arc::clone(&self.metadata) + fn get_name(&self) -> &str { + &self.name } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index d8e1f557141f..85269ae2fa37 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1361,8 +1361,8 @@ mod tests { ) } - fn get_metadata(&self) -> std::sync::Arc { - std::sync::Arc::new(Self::metadata()) + fn get_name(&self) -> &str { + "mock-scheduler" } fn get_model_config(&self) -> ModelConfig { diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 6fb7864c0320..77e7359f6987 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -557,8 +557,8 @@ mod final_output_tool_tests { goose::providers::base::ProviderMetadata::empty() } - fn get_metadata(&self) -> std::sync::Arc { - std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock-test" } fn get_model_config(&self) -> ModelConfig { @@ -676,8 +676,8 @@ mod final_output_tool_tests { goose::providers::base::ProviderMetadata::empty() } - fn get_metadata(&self) -> std::sync::Arc { - std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock-test" } fn get_model_config(&self) -> ModelConfig { @@ -866,8 +866,8 @@ mod retry_tests { goose::providers::base::ProviderMetadata::empty() } - fn get_metadata(&self) -> std::sync::Arc { - std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock-test" } fn get_model_config(&self) -> ModelConfig { @@ -1093,8 +1093,8 @@ mod max_turns_tests { } } - fn get_metadata(&self) -> std::sync::Arc { - std::sync::Arc::new(goose::providers::base::ProviderMetadata::empty()) + fn get_name(&self) -> &str { + "mock-test" } } From 8eb93bad1ab77de7e732b9a063ca20a54368f79e Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Tue, 28 Oct 2025 13:48:21 -0400 Subject: [PATCH 3/3] don't pass custom provider name twice for no reason --- .../goose/src/config/declarative_providers.rs | 27 +++---------------- crates/goose/src/providers/anthropic.rs | 3 +-- crates/goose/src/providers/ollama.rs | 3 +-- crates/goose/src/providers/openai.rs | 3 +-- 4 files changed, 6 insertions(+), 30 deletions(-) diff --git a/crates/goose/src/config/declarative_providers.rs b/crates/goose/src/config/declarative_providers.rs index 5255a00239b6..10df3ff1021b 100644 --- a/crates/goose/src/config/declarative_providers.rs +++ b/crates/goose/src/config/declarative_providers.rs @@ -293,45 +293,24 @@ pub fn register_declarative_provider( match config.engine { ProviderEngine::OpenAI => { - let provider_name = config.name.clone(); registry.register_with_name::( &config, provider_type, - move |model| { - OpenAiProvider::from_custom_config( - model, - config_clone.clone(), - provider_name.clone(), - ) - }, + move |model| OpenAiProvider::from_custom_config(model, config_clone.clone()), ); } ProviderEngine::Ollama => { - let provider_name = config.name.clone(); registry.register_with_name::( &config, provider_type, - move |model| { - OllamaProvider::from_custom_config( - model, - config_clone.clone(), - provider_name.clone(), - ) - }, + move |model| OllamaProvider::from_custom_config(model, config_clone.clone()), ); } ProviderEngine::Anthropic => { - let provider_name = config.name.clone(); registry.register_with_name::( &config, provider_type, - move |model| { - AnthropicProvider::from_custom_config( - model, - config_clone.clone(), - provider_name.clone(), - ) - }, + move |model| AnthropicProvider::from_custom_config(model, config_clone.clone()), ); } } diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 5a47c128d97c..06dbd8fecfad 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -75,7 +75,6 @@ impl AnthropicProvider { pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, - provider_name: String, ) -> Result { let global_config = crate::config::Config::global(); let api_key: String = global_config @@ -94,7 +93,7 @@ impl AnthropicProvider { api_client, model, supports_streaming: config.supports_streaming.unwrap_or(true), - name: provider_name, + name: config.name.clone(), }) } diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 98b1d9f826e5..3a180084634b 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -100,7 +100,6 @@ impl OllamaProvider { pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, - provider_name: String, ) -> Result { let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT)); @@ -134,7 +133,7 @@ impl OllamaProvider { api_client, model, supports_streaming: config.supports_streaming.unwrap_or(true), - name: provider_name, + name: config.name.clone(), }) } diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index e347f6db091b..720cebce136a 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -115,7 +115,6 @@ impl OpenAiProvider { pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, - provider_name: String, ) -> Result { let global_config = crate::config::Config::global(); let api_key: String = global_config @@ -166,7 +165,7 @@ impl OpenAiProvider { model, custom_headers: config.headers, supports_streaming: config.supports_streaming.unwrap_or(true), - name: provider_name, + name: config.name.clone(), }) }