diff --git a/clippy-baselines/too_many_lines.txt b/clippy-baselines/too_many_lines.txt
index e69d0ef3a1ce..bf3d3f536b72 100644
--- a/clippy-baselines/too_many_lines.txt
+++ b/clippy-baselines/too_many_lines.txt
@@ -16,6 +16,7 @@ crates/goose/src/agents/agent.rs::create_recipe
 crates/goose/src/agents/agent.rs::dispatch_tool_call
 crates/goose/src/agents/agent.rs::reply
 crates/goose/src/agents/agent.rs::reply_internal
+crates/goose/src/providers/canonical/build_canonical_models.rs::build_canonical_models
 crates/goose/src/providers/formats/anthropic.rs::format_messages
 crates/goose/src/providers/formats/anthropic.rs::response_to_streaming_message
 crates/goose/src/providers/formats/databricks.rs::format_messages
diff --git a/crates/goose-acp/tests/fixtures/mod.rs b/crates/goose-acp/tests/fixtures/mod.rs
index f79cbf8d16fa..51bd708d39e8 100644
--- a/crates/goose-acp/tests/fixtures/mod.rs
+++ b/crates/goose-acp/tests/fixtures/mod.rs
@@ -1,3 +1,6 @@
+#![recursion_limit = "256"]
+#![allow(unused_attributes)]
+
 use assert_json_diff::{assert_json_matches_no_panic, CompareMode, Config};
 use async_trait::async_trait;
 use fs_err as fs;
diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs
index 854d87eca129..15664f6473ca 100644
--- a/crates/goose/src/agents/reply_parts.rs
+++ b/crates/goose/src/agents/reply_parts.rs
@@ -431,10 +431,6 @@ mod tests {
 
     #[async_trait]
     impl Provider for MockProvider {
-        fn metadata() -> crate::providers::base::ProviderMetadata {
-            crate::providers::base::ProviderMetadata::empty()
-        }
-
         fn get_name(&self) -> &str {
             "mock"
         }
diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs
index fb45061b61ca..ddb32b3bdfb3 100644
--- a/crates/goose/src/context_mgmt/mod.rs
+++ b/crates/goose/src/context_mgmt/mod.rs
@@ -537,10 +537,7 @@ mod tests {
     use super::*;
    use crate::{
         model::ModelConfig,
-        providers::{
-            base::{ProviderMetadata, Usage},
-            errors::ProviderError,
-        },
+        providers::{base::Usage, errors::ProviderError},
     };
     use async_trait::async_trait;
     use rmcp::model::{AnnotateAble, CallToolRequestParams, RawContent, Tool};
@@ -577,10 +574,6 @@ mod tests {
 
     #[async_trait]
     impl Provider for MockProvider {
-        fn metadata() -> ProviderMetadata {
-            ProviderMetadata::new("mock", "", "", "", vec![""], "", vec![])
-        }
-
         fn get_name(&self) -> &str {
             "mock"
         }
diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs
index 82e204a66213..dbc2026e0877 100644
--- a/crates/goose/src/providers/anthropic.rs
+++ b/crates/goose/src/providers/anthropic.rs
@@ -9,19 +9,25 @@ use tokio::pin;
 use tokio_util::io::StreamReader;
 
 use super::api_client::{ApiClient, ApiResponse, AuthMethod};
-use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
+use super::base::{
+    ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
+};
 use super::errors::ProviderError;
 use super::formats::anthropic::{
     create_request, get_usage, response_to_message, response_to_streaming_message,
 };
-use super::utils::{get_model, handle_status_openai_compat, map_http_error_to_provider_error};
+use super::openai_compatible::handle_status_openai_compat;
+use super::openai_compatible::map_http_error_to_provider_error;
+use super::utils::get_model;
 use crate::config::declarative_providers::DeclarativeProviderConfig;
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::retry::ProviderRetry;
 use crate::providers::utils::RequestLog;
+use futures::future::BoxFuture;
 use rmcp::model::Tool;
 
+const ANTHROPIC_PROVIDER_NAME: &str = "anthropic";
 pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-5";
 const ANTHROPIC_DEFAULT_FAST_MODEL: &str = "claude-haiku-4-5";
 const ANTHROPIC_KNOWN_MODELS: &[&str] = &[
@@ -73,7 +79,7 @@ impl AnthropicProvider {
             api_client,
             model,
             supports_streaming: true,
-            name: Self::metadata().name,
+            name: ANTHROPIC_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -171,8 +177,9 @@ impl AnthropicProvider {
     }
 }
 
-#[async_trait]
-impl Provider for AnthropicProvider {
+impl ProviderDef for AnthropicProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         let models: Vec<ModelInfo> = ANTHROPIC_KNOWN_MODELS
             .iter()
@@ -180,7 +187,7 @@ impl Provider for AnthropicProvider {
             .collect();
 
         ProviderMetadata::with_models(
-            "anthropic",
+            ANTHROPIC_PROVIDER_NAME,
             "Anthropic",
             "Claude and other models from Anthropic",
             ANTHROPIC_DEFAULT_MODEL,
@@ -198,6 +205,13 @@ impl Provider for AnthropicProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for AnthropicProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/api_client.rs b/crates/goose/src/providers/api_client.rs
index 627f5435d535..113c03837561 100644
--- a/crates/goose/src/providers/api_client.rs
+++ b/crates/goose/src/providers/api_client.rs
@@ -16,6 +16,7 @@ pub struct ApiClient {
     host: String,
     auth: AuthMethod,
     default_headers: HeaderMap,
+    default_query: Vec<(String, String)>,
     timeout: Duration,
     tls_config: Option<TlsConfig>,
 }
@@ -222,6 +223,7 @@ impl ApiClient {
             host,
             auth,
             default_headers: HeaderMap::new(),
+            default_query: Vec::new(),
             timeout,
             tls_config,
         })
@@ -267,6 +269,11 @@ impl ApiClient {
         Ok(self)
     }
 
+    pub fn with_query(mut self, params: Vec<(String, String)>) -> Self {
+        self.default_query = params;
+        self
+    }
+
     pub fn with_header(mut self, key: &str, value: &str) -> Result<Self> {
         let header_name = HeaderName::from_bytes(key.as_bytes())?;
         let header_value = HeaderValue::from_str(value)?;
@@ -325,9 +332,15 @@ impl ApiClient {
             base_url.set_path(&format!("{}/", base_path));
         }
 
-        base_url
-            .join(path)
-            .map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e))
+        let mut url = base_url
+            .join(path)
+            .map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e))?;
+
+        for (key, value) in &self.default_query {
+            url.query_pairs_mut().append_pair(key, value);
+        }
+
+        Ok(url)
     }
 
     async fn get_oauth_token(&self, config: &OAuthConfig) -> Result<String> {
diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs
index cbecb88dd171..160ab71f3569 100644
--- a/crates/goose/src/providers/azure.rs
+++ b/crates/goose/src/providers/azure.rs
@@ -1,47 +1,21 @@
 use anyhow::Result;
 use async_trait::async_trait;
-use serde::Serialize;
-use serde_json::Value;
 
 use super::api_client::{ApiClient, AuthMethod, AuthProvider};
 use super::azureauth::{AuthError, AzureAuth};
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
-use super::errors::ProviderError;
-use super::formats::openai::{create_request, get_usage, response_to_message};
-use super::retry::ProviderRetry;
-use super::utils::{get_model, handle_response_openai_compat, ImageFormat};
-use crate::conversation::message::Message;
+use super::base::{ConfigKey, ProviderDef, ProviderMetadata};
+use super::openai_compatible::OpenAiCompatibleProvider;
 use crate::model::ModelConfig;
-use crate::providers::utils::RequestLog;
-use rmcp::model::Tool;
+use futures::future::BoxFuture;
 
+const AZURE_PROVIDER_NAME: &str = "azure_openai";
 pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
 pub const AZURE_DOC_URL: &str =
     "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
 pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
 pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
 
-#[derive(Debug)]
-pub struct AzureProvider {
-    api_client: ApiClient,
-    deployment_name: String,
-    api_version: String,
-    model: ModelConfig,
-    name: String,
-}
-
-impl Serialize for AzureProvider {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        use serde::ser::SerializeStruct;
-        let mut state = serializer.serialize_struct("AzureProvider", 2)?;
-        state.serialize_field("deployment_name", &self.deployment_name)?;
-        state.serialize_field("api_version", &self.api_version)?;
-        state.end()
-    }
-}
+pub struct AzureProvider;
 
 // Custom auth provider that wraps AzureAuth
 struct AzureAuthProvider {
@@ -69,60 +43,12 @@ impl AuthProvider for AzureAuthProvider {
     }
 }
 
-impl AzureProvider {
-    pub async fn from_env(model: ModelConfig) -> Result<Self> {
-        let config = crate::config::Config::global();
-        let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
-        let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
-        let api_version: String = config
-            .get_param("AZURE_OPENAI_API_VERSION")
-            .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
-
-        let api_key = config
-            .get_secret("AZURE_OPENAI_API_KEY")
-            .ok()
-            .filter(|key: &String| !key.is_empty());
-        let auth = AzureAuth::new(api_key).map_err(|e| match e {
-            AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg),
-            AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg),
-        })?;
-
-        let auth_provider = AzureAuthProvider { auth };
-        let api_client = ApiClient::new(endpoint, AuthMethod::Custom(Box::new(auth_provider)))?;
-
-        Ok(Self {
-            api_client,
-            deployment_name,
-            api_version,
-            model,
-            name: Self::metadata().name,
-        })
-    }
-
-    async fn post(
-        &self,
-        session_id: Option<&str>,
-        payload: &Value,
-    ) -> Result<Value, ProviderError> {
-        // Build the path for Azure OpenAI
-        let path = format!(
-            "openai/deployments/{}/chat/completions?api-version={}",
-            self.deployment_name, self.api_version
-        );
+impl ProviderDef for AzureProvider {
+    type Provider = OpenAiCompatibleProvider;
 
-        let response = self
-            .api_client
-            .response_post(session_id, &path, payload)
-            .await?;
-        handle_response_openai_compat(response).await
-    }
-}
-
-#[async_trait]
-impl Provider for AzureProvider {
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "azure_openai",
+            AZURE_PROVIDER_NAME,
             "Azure OpenAI",
             "Models through Azure OpenAI Service (uses Azure credential chain by default)",
             "gpt-4o",
@@ -137,49 +63,38 @@ impl Provider for AzureProvider {
         )
     }
 
-    fn get_name(&self) -> &str {
-        &self.name
-    }
-
-    fn get_model_config(&self) -> ModelConfig {
-        self.model.clone()
-    }
-
-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(
-            model_config,
-            system,
-            messages,
-            tools,
-            &ImageFormat::OpenAi,
-            false,
-        )?;
-        let response = self
-            .with_retry(|| async {
-                let payload_clone = payload.clone();
-                self.post(session_id, &payload_clone).await
-            })
-            .await?;
-
-        let message = response_to_message(&response)?;
-        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
-            tracing::debug!("Failed to get usage data");
-            Usage::default()
-        });
-        let response_model = get_model(&response);
-        let mut log = RequestLog::start(model_config, &payload)?;
-        log.write(&response, Some(&usage))?;
-        Ok((message, ProviderUsage::new(response_model, usage)))
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(async move {
+            let config = crate::config::Config::global();
+            let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
+            let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
+            let api_version: String = config
+                .get_param("AZURE_OPENAI_API_VERSION")
+                .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
+
+            let api_key = config
+                .get_secret("AZURE_OPENAI_API_KEY")
+                .ok()
+                .filter(|key: &String| !key.is_empty());
+            let auth = AzureAuth::new(api_key).map_err(|e| match e {
+                AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg),
+                AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg),
+            })?;
+
+            let auth_provider = AzureAuthProvider { auth };
+            let host = format!(
+                "{}/openai/deployments/{}",
+                endpoint.trim_end_matches('/'),
+                deployment_name
+            );
+            let api_client = ApiClient::new(host, AuthMethod::Custom(Box::new(auth_provider)))?
+                .with_query(vec![("api-version".to_string(), api_version)]);
+
+            Ok(OpenAiCompatibleProvider::new(
+                AZURE_PROVIDER_NAME.to_string(),
+                api_client,
+                model,
+            ))
+        })
     }
 }
diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs
index 6eda09232a36..ab343750e1a9 100644
--- a/crates/goose/src/providers/base.rs
+++ b/crates/goose/src/providers/base.rs
@@ -1,4 +1,5 @@
 use anyhow::Result;
+use futures::future::BoxFuture;
 use futures::Stream;
 use serde::{Deserialize, Serialize};
 
@@ -343,6 +344,18 @@ impl Usage {
 
 use async_trait::async_trait;
 
+pub trait ProviderDef: Send + Sync {
+    type Provider: Provider + 'static;
+
+    fn metadata() -> ProviderMetadata
+    where
+        Self: Sized;
+
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>>
+    where
+        Self: Sized;
+}
+
 /// Trait for LeadWorkerProvider-specific functionality
 pub trait LeadWorkerProviderTrait {
     /// Get information about the lead and worker models for logging
@@ -358,11 +371,6 @@ pub trait LeadWorkerProviderTrait {
 
 /// Base trait for AI providers (OpenAI, Anthropic, etc)
 #[async_trait]
 pub trait Provider: Send + Sync {
-    /// Get the metadata for this provider type
-    fn metadata() -> ProviderMetadata
-    where
-        Self: Sized;
-
     /// Get the name of this provider instance
     fn get_name(&self) -> &str;
 
diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs
index debfc9098019..3551a54c9835 100644
--- a/crates/goose/src/providers/bedrock.rs
+++ b/crates/goose/src/providers/bedrock.rs
@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
+use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage};
 use super::errors::ProviderError;
 use super::retry::{ProviderRetry, RetryConfig};
 use crate::conversation::message::Message;
@@ -11,6 +11,7 @@ use async_trait::async_trait;
 use aws_sdk_bedrockruntime::config::ProvideCredentials;
 use aws_sdk_bedrockruntime::operation::converse::ConverseError;
 use aws_sdk_bedrockruntime::{types as bedrock, Client};
+use futures::future::BoxFuture;
 use reqwest::header::HeaderValue;
 use rmcp::model::Tool;
 use serde_json::Value;
@@ -21,6 +22,7 @@ use super::formats::bedrock::{
 };
 use crate::session_context::SESSION_ID_HEADER;
 
+const BEDROCK_PROVIDER_NAME: &str = "aws_bedrock";
 pub const BEDROCK_DOC_LINK: &str =
     "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";
 
@@ -140,7 +142,7 @@ impl BedrockProvider {
             client,
             model,
             retry_config,
-            name: Self::metadata().name,
+            name: BEDROCK_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -261,11 +263,12 @@ impl BedrockProvider {
     }
 }
 
-#[async_trait]
-impl Provider for BedrockProvider {
+impl ProviderDef for BedrockProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "aws_bedrock",
+            BEDROCK_PROVIDER_NAME,
             "Amazon Bedrock",
             "Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile <profile>' before using. Configure with AWS_PROFILE and AWS_REGION, use environment variables/credentials, or use AWS_BEARER_TOKEN_BEDROCK for bearer token authentication. Region is required for bearer token auth (can be set via AWS_REGION, AWS_DEFAULT_REGION, or AWS profile).",
             BEDROCK_DEFAULT_MODEL,
@@ -279,6 +282,13 @@ impl Provider for BedrockProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for BedrockProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/chatgpt_codex.rs b/crates/goose/src/providers/chatgpt_codex.rs
index 53491d3c96dc..603d2eeeaa91 100644
--- a/crates/goose/src/providers/chatgpt_codex.rs
+++ b/crates/goose/src/providers/chatgpt_codex.rs
@@ -2,11 +2,13 @@ use crate::config::paths::Paths;
 use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::providers::api_client::AuthProvider;
-use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage};
+use crate::providers::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
+};
 use crate::providers::errors::ProviderError;
 use crate::providers::formats::openai_responses::responses_api_to_streaming_message;
+use crate::providers::openai_compatible::handle_status_openai_compat;
 use crate::providers::retry::ProviderRetry;
-use crate::providers::utils::handle_status_openai_compat;
 use crate::session_context::SESSION_ID_HEADER;
 use anyhow::{anyhow, Result};
 use async_stream::try_stream;
@@ -14,6 +16,7 @@ use async_trait::async_trait;
 use axum::{extract::Query, response::Html, routing::get, Router};
 use base64::Engine;
 use chrono::{DateTime, Utc};
+use futures::future::BoxFuture;
 use futures::{StreamExt, TryStreamExt};
 use jsonwebtoken::jwk::JwkSet;
 use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
@@ -43,6 +46,7 @@ const OAUTH_PORT: u16 = 1455;
 const OAUTH_TIMEOUT_SECS: u64 = 300;
 const HTML_AUTO_CLOSE_TIMEOUT_MS: u64 = 2000;
 
+const CHATGPT_CODEX_PROVIDER_NAME: &str = "chatgpt_codex";
 pub const CHATGPT_CODEX_DEFAULT_MODEL: &str = "gpt-5.1-codex";
 pub const CHATGPT_CODEX_KNOWN_MODELS: &[&str] = &[
     "gpt-5.2-codex",
@@ -787,7 +791,7 @@ impl ChatGptCodexProvider {
         Ok(Self {
             auth_provider,
             model,
-            name: Self::metadata().name,
+            name: CHATGPT_CODEX_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -837,11 +841,12 @@ impl ChatGptCodexProvider {
     }
 }
 
-#[async_trait]
-impl Provider for ChatGptCodexProvider {
+impl ProviderDef for ChatGptCodexProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "chatgpt_codex",
+            CHATGPT_CODEX_PROVIDER_NAME,
             "ChatGPT Codex",
             "Use your ChatGPT Plus/Pro subscription for GPT-5 Codex models via OAuth",
             CHATGPT_CODEX_DEFAULT_MODEL,
@@ -856,6 +861,13 @@ impl Provider for ChatGptCodexProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for ChatGptCodexProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs
index 594c0bcf1855..1e6417c19dc7 100644
--- a/crates/goose/src/providers/claude_code.rs
+++ b/crates/goose/src/providers/claude_code.rs
@@ -8,7 +8,7 @@ use std::process::Stdio;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;
 
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::utils::{filter_extensions_from_system_prompt, RequestLog};
 use crate::config::base::ClaudeCodeCommand;
@@ -17,8 +17,10 @@ use crate::config::{Config, GooseMode};
 use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::subprocess::configure_command_no_window;
+use futures::future::BoxFuture;
 use rmcp::model::Tool;
 
+const CLAUDE_CODE_PROVIDER_NAME: &str = "claude-code";
 pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";
 pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["sonnet", "opus"];
 pub const CLAUDE_CODE_DOC_URL: &str = "https://code.claude.com/docs/en/setup";
@@ -40,7 +42,7 @@ impl ClaudeCodeProvider {
         Ok(Self {
             command: resolved_command,
             model,
-            name: Self::metadata().name,
+            name: CLAUDE_CODE_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -388,11 +390,12 @@ impl ClaudeCodeProvider {
     }
 }
 
-#[async_trait]
-impl Provider for ClaudeCodeProvider {
+impl ProviderDef for ClaudeCodeProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "claude-code",
+            CLAUDE_CODE_PROVIDER_NAME,
             "Claude Code CLI",
             "Requires claude CLI installed, no MCPs. Use Anthropic provider for full features.",
             CLAUDE_CODE_DEFAULT_MODEL,
@@ -402,6 +405,13 @@ impl Provider for ClaudeCodeProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for ClaudeCodeProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs
index c72afc90ca12..6234483a22ce 100644
--- a/crates/goose/src/providers/codex.rs
+++ b/crates/goose/src/providers/codex.rs
@@ -7,7 +7,7 @@ use std::process::Stdio;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;
 
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::utils::{filter_extensions_from_system_prompt, RequestLog};
 use crate::config::base::{CodexCommand, CodexReasoningEffort, CodexSkipGitCheck};
@@ -16,9 +16,11 @@ use crate::config::{Config, GooseMode};
 use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::subprocess::configure_command_no_window;
+use futures::future::BoxFuture;
 use rmcp::model::Role;
 use rmcp::model::Tool;
 
+const CODEX_PROVIDER_NAME: &str = "codex";
 pub const CODEX_DEFAULT_MODEL: &str = "gpt-5.2-codex";
 pub const CODEX_KNOWN_MODELS: &[&str] = &[
     "gpt-5.2-codex",
@@ -77,7 +79,7 @@ impl CodexProvider {
         Ok(Self {
             command: resolved_command,
             model,
-            name: Self::metadata().name,
+            name: CODEX_PROVIDER_NAME.to_string(),
             reasoning_effort,
             skip_git_check,
         })
@@ -471,11 +473,12 @@ impl CodexProvider {
     }
 }
 
-#[async_trait]
-impl Provider for CodexProvider {
+impl ProviderDef for CodexProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "codex",
+            CODEX_PROVIDER_NAME,
             "OpenAI Codex CLI",
             "Execute OpenAI models via Codex CLI tool. Requires codex CLI installed.",
             CODEX_DEFAULT_MODEL,
@@ -489,6 +492,13 @@ impl Provider for CodexProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for CodexProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs
index 6b31e2b47c48..cadd7f97a363 100644
--- a/crates/goose/src/providers/cursor_agent.rs
+++ b/crates/goose/src/providers/cursor_agent.rs
@@ -8,7 +8,7 @@ use std::process::Stdio;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;
 
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::utils::{filter_extensions_from_system_prompt, RequestLog};
 use crate::config::base::CursorAgentCommand;
@@ -16,8 +16,10 @@ use crate::config::search_path::SearchPaths;
 use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::subprocess::configure_command_no_window;
+use futures::future::BoxFuture;
 use rmcp::model::Tool;
 
+const CURSOR_AGENT_PROVIDER_NAME: &str = "cursor-agent";
 pub const CURSOR_AGENT_DEFAULT_MODEL: &str = "auto";
 pub const CURSOR_AGENT_KNOWN_MODELS: &[&str] = &["auto", "gpt-5", "opus-4.1", "sonnet-4"];
 
@@ -40,7 +42,7 @@ impl CursorAgentProvider {
         Ok(Self {
             command: resolved_command,
             model,
-            name: Self::metadata().name,
+            name: CURSOR_AGENT_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -321,11 +323,12 @@ impl CursorAgentProvider {
     }
 }
 
-#[async_trait]
-impl Provider for CursorAgentProvider {
+impl ProviderDef for CursorAgentProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "cursor-agent",
+            CURSOR_AGENT_PROVIDER_NAME,
             "Cursor Agent",
             "Execute AI models via cursor-agent CLI tool",
             CURSOR_AGENT_DEFAULT_MODEL,
@@ -337,6 +340,13 @@ impl Provider for CursorAgentProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for CursorAgentProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs
index 7139a2024b62..16fced6ac299 100644
--- a/crates/goose/src/providers/databricks.rs
+++ b/crates/goose/src/providers/databricks.rs
@@ -1,20 +1,23 @@
 use anyhow::Result;
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use std::time::Duration;
 
 use super::api_client::{ApiClient, AuthMethod, AuthProvider};
-use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
+};
 use super::embedding::EmbeddingCapable;
 use super::errors::ProviderError;
 use super::formats::databricks::{create_request, response_to_message};
 use super::oauth;
-use super::retry::ProviderRetry;
-use super::utils::{
-    get_model, handle_response_openai_compat, map_http_error_to_provider_error,
-    stream_openai_compat, ImageFormat, RequestLog,
+use super::openai_compatible::{
+    handle_response_openai_compat, map_http_error_to_provider_error, stream_openai_compat,
 };
+use super::retry::ProviderRetry;
+use super::utils::{get_model, ImageFormat, RequestLog};
 use crate::config::ConfigError;
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
@@ -31,6 +34,7 @@ const DEFAULT_REDIRECT_URL: &str = "http://localhost";
 const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
 const DEFAULT_TIMEOUT_SECS: u64 = 600;
 
+const DATABRICKS_PROVIDER_NAME: &str = "databricks";
 pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
 const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-2-5-flash";
 pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
@@ -140,7 +144,7 @@ impl DatabricksProvider {
             model: model.clone(),
             image_format: ImageFormat::OpenAi,
             retry_config,
-            name: Self::metadata().name,
+            name: DATABRICKS_PROVIDER_NAME.to_string(),
         };
         provider.model = model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL.to_string());
         Ok(provider)
@@ -192,7 +196,7 @@ impl DatabricksProvider {
             model,
             image_format: ImageFormat::OpenAi,
             retry_config: RetryConfig::default(),
-            name: Self::metadata().name,
+            name: DATABRICKS_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -222,11 +226,12 @@ impl DatabricksProvider {
     }
 }
 
-#[async_trait]
-impl Provider for DatabricksProvider {
+impl ProviderDef for DatabricksProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "databricks",
+            DATABRICKS_PROVIDER_NAME,
             "Databricks",
             "Models on Databricks AI Gateway",
             DATABRICKS_DEFAULT_MODEL,
@@ -239,6 +244,13 @@ impl Provider for DatabricksProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for DatabricksProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs
index fafda26f180f..195d185f3796 100644
--- a/crates/goose/src/providers/gcpvertexai.rs
+++ b/crates/goose/src/providers/gcpvertexai.rs
@@ -4,6 +4,7 @@ use std::time::Duration;
 use anyhow::Result;
 use async_stream::try_stream;
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 use futures::StreamExt;
 use futures::TryStreamExt;
 use once_cell::sync::Lazy;
@@ -15,7 +16,9 @@ use url::Url;
 
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage};
+use crate::providers::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
+};
 use crate::providers::errors::ProviderError;
 
 use crate::providers::formats::gcpvertexai::{
@@ -28,6 +31,7 @@ use crate::providers::utils::RequestLog;
 use crate::session_context::SESSION_ID_HEADER;
 use rmcp::model::Tool;
 
+const GCP_VERTEX_AI_PROVIDER_NAME: &str = "gcp_vertex_ai";
 /// Base URL for GCP Vertex AI documentation
 const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
 /// Default timeout for API requests in seconds
@@ -175,7 +179,7 @@ impl GcpVertexAIProvider {
             location,
             model,
             retry_config,
-            name: Self::metadata().name,
+            name: GCP_VERTEX_AI_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -541,14 +545,12 @@ impl GcpVertexAIProvider {
     }
 }
 
-#[async_trait]
-impl Provider for GcpVertexAIProvider {
-    fn metadata() -> ProviderMetadata
-    where
-        Self: Sized,
-    {
+impl ProviderDef for GcpVertexAIProvider {
+    type Provider = Self;
+
+    fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "gcp_vertex_ai",
+            GCP_VERTEX_AI_PROVIDER_NAME,
             "GCP Vertex AI",
             "Access variety of AI models such as Claude, Gemini through Vertex AI",
             DEFAULT_MODEL,
@@ -591,6 +593,13 @@ impl Provider for GcpVertexAIProvider {
         .with_unlisted_models()
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for GcpVertexAIProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs
index dbab2d201f80..4adf8b7925bc 100644
--- a/crates/goose/src/providers/gemini_cli.rs
+++ b/crates/goose/src/providers/gemini_cli.rs
@@ -7,7 +7,7 @@ use std::process::Stdio;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;
 
-use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::utils::{filter_extensions_from_system_prompt, RequestLog};
 use crate::config::base::GeminiCliCommand;
@@ -17,9 +17,11 @@ use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::providers::base::ConfigKey;
 use crate::subprocess::configure_command_no_window;
+use futures::future::BoxFuture;
 use rmcp::model::Role;
 use rmcp::model::Tool;
 
+const GEMINI_CLI_PROVIDER_NAME: &str = "gemini-cli";
 pub const GEMINI_CLI_DEFAULT_MODEL: &str = "gemini-2.5-pro";
 pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &[
     "gemini-2.5-pro",
@@ -46,7 +48,7 @@ impl GeminiCliProvider {
         Ok(Self {
             command: resolved_command,
             model,
-            name: Self::metadata().name,
+            name: GEMINI_CLI_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -235,11 +237,12 @@ impl GeminiCliProvider {
     }
 }
 
-#[async_trait]
-impl Provider for GeminiCliProvider {
+impl ProviderDef for GeminiCliProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "gemini-cli",
+            GEMINI_CLI_PROVIDER_NAME,
             "Gemini CLI",
             "Execute Gemini models via gemini CLI tool",
             GEMINI_CLI_DEFAULT_MODEL,
@@ -249,6 +252,13 @@ impl Provider for GeminiCliProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for GeminiCliProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs
index 7b2d702a4389..821f57dea000 100644
--- a/crates/goose/src/providers/githubcopilot.rs
+++ b/crates/goose/src/providers/githubcopilot.rs
@@ -1,6 +1,6 @@
 use crate::config::paths::Paths;
 use crate::providers::api_client::{ApiClient, AuthMethod};
-use crate::providers::utils::{handle_status_openai_compat, stream_openai_compat};
+use crate::providers::openai_compatible::{handle_status_openai_compat, stream_openai_compat};
 use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
 use axum::http;
@@ -13,19 +13,22 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use std::time::Duration;
 
-use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::formats::openai::{create_request, get_usage, response_to_message};
+use super::openai_compatible::handle_response_openai_compat;
 use super::retry::ProviderRetry;
-use super::utils::{get_model, handle_response_openai_compat, ImageFormat, RequestLog};
+use super::utils::{get_model, ImageFormat, RequestLog};
 use crate::config::{Config, ConfigError};
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::base::{ConfigKey, MessageStream};
+use futures::future::BoxFuture;
 use rmcp::model::Tool;
 
+const GITHUB_COPILOT_PROVIDER_NAME: &str = "github_copilot";
 pub const GITHUB_COPILOT_DEFAULT_MODEL: &str = "gpt-4.1";
 pub const GITHUB_COPILOT_KNOWN_MODELS: &[&str] = &[
     "gpt-4.1",
@@ -165,7 +168,7 @@ impl GithubCopilotProvider {
             cache,
             mu,
             model,
-            name: Self::metadata().name,
+            name: GITHUB_COPILOT_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -376,11 +379,12 @@ impl GithubCopilotProvider {
     }
 }
 
-#[async_trait]
-impl Provider for GithubCopilotProvider {
+impl ProviderDef for GithubCopilotProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "github_copilot",
+            GITHUB_COPILOT_PROVIDER_NAME,
             "GitHub Copilot",
             "GitHub Copilot. Run `goose configure` and select copilot to set up.",
             GITHUB_COPILOT_DEFAULT_MODEL,
@@ -395,6 +399,13 @@ impl Provider for GithubCopilotProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for GithubCopilotProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs
index 7f560412c39c..736340d38ef6 100644
--- a/crates/goose/src/providers/google.rs
+++ b/crates/goose/src/providers/google.rs
@@ -1,20 +1,20 @@
 use super::api_client::{ApiClient, AuthMethod};
 use super::base::MessageStream;
 use super::errors::ProviderError;
+use super::openai_compatible::handle_status_openai_compat;
 use super::retry::ProviderRetry;
-use super::utils::{
-    handle_response_google_compat, handle_status_openai_compat, unescape_json_values, RequestLog,
-};
+use super::utils::{handle_response_google_compat, unescape_json_values, RequestLog};
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
+use crate::providers::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage};
 use crate::providers::formats::google::{
     create_request, get_usage, response_to_message, response_to_streaming_message,
 };
 use anyhow::Result;
 use async_stream::try_stream;
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 use futures::TryStreamExt;
 use rmcp::model::Tool;
 use serde_json::Value;
@@ -24,6 +24,7 @@ use tokio_stream::StreamExt;
 use tokio_util::codec::{FramedRead, LinesCodec};
 use tokio_util::io::StreamReader;
 
+const GOOGLE_PROVIDER_NAME: &str = "google";
 pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
 pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.5-pro";
 pub const GOOGLE_DEFAULT_FAST_MODEL: &str = "gemini-2.5-flash";
@@ -87,7 +88,7 @@ impl GoogleProvider {
         Ok(Self {
             api_client,
             model,
-            name: Self::metadata().name,
+            name: GOOGLE_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -120,11 +121,12 @@ impl GoogleProvider {
     }
 }
 
-#[async_trait]
-impl Provider for GoogleProvider {
+impl ProviderDef for GoogleProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "google",
+            GOOGLE_PROVIDER_NAME,
             "Google Gemini",
             "Gemini models from Google AI",
             GOOGLE_DEFAULT_MODEL,
@@ -137,6 +139,13 @@ impl Provider for GoogleProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for GoogleProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/init.rs
similarity index 82%
rename from crates/goose/src/providers/factory.rs
rename to crates/goose/src/providers/init.rs
index 281f02ce3695..62344c3d5d56 100644
--- a/crates/goose/src/providers/factory.rs
+++ b/crates/goose/src/providers/init.rs
@@ -43,48 +43,27 @@ static REGISTRY: OnceCell<RwLock<ProviderRegistry>> = OnceCell::const_new();
 
 async fn init_registry() -> RwLock<ProviderRegistry> {
     let mut registry = ProviderRegistry::new().with_providers(|registry| {
-        registry
-            .register::<AnthropicProvider, _>(|m| Box::pin(AnthropicProvider::from_env(m)), true);
-        registry.register::<AzureProvider, _>(|m| Box::pin(AzureProvider::from_env(m)), false);
-        registry.register::<BedrockProvider, _>(|m| Box::pin(BedrockProvider::from_env(m)), false);
-        registry.register::<ChatGptCodexProvider, _>(
-            |m| Box::pin(ChatGptCodexProvider::from_env(m)),
-            true,
-        );
-        registry
-            .register::<ClaudeCodeProvider, _>(|m| Box::pin(ClaudeCodeProvider::from_env(m)), true);
-        registry.register::<CodexProvider, _>(|m| Box::pin(CodexProvider::from_env(m)), true);
-        registry.register::<CursorAgentProvider, _>(
-            |m| Box::pin(CursorAgentProvider::from_env(m)),
-            false,
-        );
-        registry
-            .register::<DatabricksProvider, _>(|m| Box::pin(DatabricksProvider::from_env(m)), true);
-        registry.register::<GcpVertexAIProvider, _>(
-            |m| Box::pin(GcpVertexAIProvider::from_env(m)),
-            false,
-        );
-        registry
-            .register::<GeminiCliProvider, _>(|m| Box::pin(GeminiCliProvider::from_env(m)), false);
-        registry.register::<GithubCopilotProvider, _>(
-            |m| Box::pin(GithubCopilotProvider::from_env(m)),
-            false,
-        );
-        registry.register::<GoogleProvider, _>(|m| Box::pin(GoogleProvider::from_env(m)), true);
-        registry.register::<LiteLLMProvider, _>(|m| Box::pin(LiteLLMProvider::from_env(m)), false);
-        registry.register::<OllamaProvider, _>(|m| Box::pin(OllamaProvider::from_env(m)), true);
-        registry.register::<OpenAiProvider, _>(|m| Box::pin(OpenAiProvider::from_env(m)), true);
-        registry
-            .register::<OpenRouterProvider, _>(|m| Box::pin(OpenRouterProvider::from_env(m)), true);
-        registry.register::<SageMakerTgiProvider, _>(
-            |m| Box::pin(SageMakerTgiProvider::from_env(m)),
-            false,
-        );
-        registry
-            .register::<SnowflakeProvider, _>(|m| Box::pin(SnowflakeProvider::from_env(m)), false);
-        registry.register::<TetrateProvider, _>(|m| Box::pin(TetrateProvider::from_env(m)), true);
-        registry.register::<VeniceProvider, _>(|m| Box::pin(VeniceProvider::from_env(m)), false);
-        registry.register::<XaiProvider, _>(|m| Box::pin(XaiProvider::from_env(m)), false);
+        registry.register::<AnthropicProvider>(true);
+        registry.register::<AzureProvider>(false);
+        registry.register::<BedrockProvider>(false);
+        registry.register::<ChatGptCodexProvider>(true);
+        registry.register::<ClaudeCodeProvider>(true);
+        registry.register::<CodexProvider>(true);
+        registry.register::<CursorAgentProvider>(false);
+        registry.register::<DatabricksProvider>(true);
+        registry.register::<GcpVertexAIProvider>(false);
+        registry.register::<GeminiCliProvider>(false);
+        registry.register::<GithubCopilotProvider>(false);
+        registry.register::<GoogleProvider>(true);
+        registry.register::<LiteLLMProvider>(false);
+        registry.register::<OllamaProvider>(true);
+        registry.register::<OpenAiProvider>(true);
+        registry.register::<OpenRouterProvider>(true);
+        registry.register::<SageMakerTgiProvider>(false);
+        registry.register::<SnowflakeProvider>(false);
+        registry.register::<TetrateProvider>(true);
+        registry.register::<VeniceProvider>(false);
+        registry.register::<XaiProvider>(false);
     });
     if let Err(e) = load_custom_providers_into_registry(&mut registry) {
         tracing::warn!("Failed to load custom providers: {}", e);
     }
@@ -260,6 +239,7 @@ mod tests {
             ("GOOSE_LEAD_FAILURE_THRESHOLD", failure_threshold),
             ("GOOSE_LEAD_FALLBACK_TURNS", fallback_turns),
             ("OPENAI_API_KEY", Some("fake-openai-no-keyring")),
+            ("OPENAI_CUSTOM_HEADERS", Some("")),
         ]);
 
         let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini"))
@@ -284,6 +264,7 @@ mod tests {
             ("GOOSE_LEAD_FAILURE_THRESHOLD", None),
             ("GOOSE_LEAD_FALLBACK_TURNS", None),
             ("OPENAI_API_KEY", Some("fake-openai-no-keyring")),
+            ("OPENAI_CUSTOM_HEADERS", Some("")),
         ]);
 
         let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini"))
diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs
index 02a5784e85c4..2dbc965957b7 100644
--- a/crates/goose/src/providers/lead_worker.rs
+++ b/crates/goose/src/providers/lead_worker.rs
@@ -1,16 +1,21 @@
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use std::ops::Deref;
 use std::sync::Arc;
 use tokio::sync::Mutex;
 
-use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
+use super::base::{
+    LeadWorkerProviderTrait, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
+};
 use super::errors::ProviderError;
 use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
+use futures::future::BoxFuture;
 use rmcp::model::Tool;
 use rmcp::model::{Content, RawContent};
 
+const LEAD_WORKER_PROVIDER_NAME: &str = "lead_worker";
+
 /// A provider that switches between a lead model and a worker model based on turn count
 /// and can fallback to lead model on consecutive failures
 pub struct LeadWorkerProvider {
@@ -314,12 +319,12 @@ impl LeadWorkerProviderTrait for LeadWorkerProvider {
     }
 }
 
-#[async_trait]
-impl Provider for LeadWorkerProvider {
+impl ProviderDef for LeadWorkerProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         // This is a wrapper provider, so we return minimal metadata
         ProviderMetadata::new(
-            "lead_worker",
+            LEAD_WORKER_PROVIDER_NAME,
             "Lead/Worker Provider",
             "A provider that switches between lead and worker models based on turn count",
             "", // No default model as this is determined by the wrapped providers
@@ -329,6 +335,13 @@ impl Provider for LeadWorkerProvider {
         )
     }
 
+    fn from_env(_model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(async { Err(anyhow!("LeadWorkerProvider must be constructed explicitly")) })
+    }
+}
+
+#[async_trait]
+impl Provider for LeadWorkerProvider {
     fn get_name(&self) -> &str {
         // Return the lead provider's name as the default
         self.lead_provider.get_name()
@@ -486,7 +499,7 @@ impl Provider for LeadWorkerProvider {
 mod tests {
     use super::*;
     use crate::conversation::message::{Message, MessageContent};
-    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
+    use crate::providers::base::{ProviderUsage, Usage};
     use chrono::Utc;
     use rmcp::model::{AnnotateAble, RawTextContent, Role};
 
@@ -498,10 +511,6 @@ mod tests {
 
     #[async_trait]
     impl Provider for MockProvider {
-        fn metadata() -> ProviderMetadata {
-            ProviderMetadata::empty()
-        }
-
        fn get_name(&self) -> &str {
             "mock-lead"
         }
@@ -684,10 +693,6 @@ mod tests {
 
     #[async_trait]
     impl Provider for MockFailureProvider {
-        fn metadata() -> ProviderMetadata {
-            ProviderMetadata::empty()
-        }
-
         fn get_name(&self) -> &str {
             "mock-lead"
         }
diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs
index a27c0d289908..4d5af0d38d07 100644
--- a/crates/goose/src/providers/litellm.rs
+++ b/crates/goose/src/providers/litellm.rs
@@ -1,19 +1,21 @@
 use anyhow::Result;
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 use serde_json::{json, Value};
 use std::collections::HashMap;
 
 use super::api_client::{ApiClient, AuthMethod};
-use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
+use super::base::{ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage};
 use super::embedding::EmbeddingCapable;
 use super::errors::ProviderError;
+use super::openai_compatible::handle_response_openai_compat;
 use super::retry::ProviderRetry;
-use super::utils::{get_model, handle_response_openai_compat, ImageFormat, RequestLog};
+use super::utils::{get_model, ImageFormat, RequestLog};
 use crate::conversation::message::Message;
-
 use crate::model::ModelConfig;
 use rmcp::model::Tool;
 
+const LITELLM_PROVIDER_NAME: &str = "litellm";
 pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini";
 pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";
 
@@ -69,7 +71,7 @@ impl LiteLLMProvider {
             api_client,
             base_path,
             model,
-            name: Self::metadata().name,
+            name: LITELLM_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -129,11 +131,12 @@ impl LiteLLMProvider {
     }
 }
 
-#[async_trait]
-impl Provider for LiteLLMProvider {
+impl ProviderDef for LiteLLMProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "litellm",
+            LITELLM_PROVIDER_NAME,
             "LiteLLM",
             "LiteLLM proxy supporting multiple models with automatic prompt caching",
             LITELLM_DEFAULT_MODEL,
@@ -154,6 +157,13 @@ impl Provider for LiteLLMProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for LiteLLMProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index abfc499d657b..4da0241a08f3 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -13,18 +13,19 @@ pub mod cursor_agent;
 pub mod databricks;
 pub mod embedding;
 pub mod errors;
-mod factory;
 pub mod formats;
 mod gcpauth;
 pub mod gcpvertexai;
 pub mod gemini_cli;
 pub mod githubcopilot;
 pub mod google;
+mod init;
 pub mod lead_worker;
 pub mod litellm;
 pub mod oauth;
 pub mod ollama;
 pub mod openai;
+pub mod openai_compatible;
 pub mod openrouter;
 pub mod provider_registry;
 pub mod provider_test;
@@ -39,7 +40,7 @@ pub mod utils;
 pub mod venice;
 pub mod xai;
 
-pub use factory::{
+pub use init::{
     create, create_with_default_model, create_with_named_model, providers, refresh_custom_providers,
 };
 pub use retry::{retry_operation, RetryConfig};
diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index 02f091e5a63e..353e84b1dc1f 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -1,27 +1,30 @@
 use super::api_client::{ApiClient, AuthMethod};
-use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
+};
 use super::errors::ProviderError;
-use super::retry::ProviderRetry;
-use super::utils::{
-    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
-    RequestLog,
+use super::openai_compatible::{
+    handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
 };
+use super::retry::ProviderRetry;
+use super::utils::{get_model, ImageFormat, RequestLog};
 use crate::config::declarative_providers::DeclarativeProviderConfig;
 use crate::config::GooseMode;
 use crate::conversation::message::Message;
 use crate::conversation::Conversation;
-
 use crate::model::ModelConfig;
 use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
 use crate::utils::safe_truncate;
 use anyhow::Result;
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 use regex::Regex;
 use rmcp::model::Tool;
 use serde_json::Value;
 use std::time::Duration;
 use url::Url;
 
+const OLLAMA_PROVIDER_NAME: &str = "ollama";
 pub const OLLAMA_HOST: &str = "localhost";
 pub const OLLAMA_TIMEOUT: u64 = 600;
 pub const OLLAMA_DEFAULT_PORT: u16 = 11434;
@@ -78,7 +81,7 @@ impl OllamaProvider {
             api_client,
             model,
             supports_streaming: true,
-            name: Self::metadata().name,
+            name: OLLAMA_PROVIDER_NAME.to_string(),
         })
     }
 
@@ -142,11 +145,12 @@ impl OllamaProvider {
     }
 }
 
-#[async_trait]
-impl Provider for OllamaProvider {
+impl ProviderDef for OllamaProvider {
+    type Provider = Self;
+
     fn metadata() -> ProviderMetadata {
         ProviderMetadata::new(
-            "ollama",
+            OLLAMA_PROVIDER_NAME,
             "Ollama",
             "Local open source models",
             OLLAMA_DEFAULT_MODEL,
@@ -164,6 +168,13 @@ impl Provider for OllamaProvider {
         )
     }
 
+    fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
+        Box::pin(Self::from_env(model))
+    }
+}
+
+#[async_trait]
+impl Provider for OllamaProvider {
     fn get_name(&self) -> &str {
         &self.name
     }
@@ -197,7 +208,7 @@ impl Provider for OllamaProvider {
             system,
             messages,
             filtered_tools,
-            &super::utils::ImageFormat::OpenAi,
+            &ImageFormat::OpenAi,
             false,
         )?;
 
@@ -269,7 +280,7 @@ impl Provider for OllamaProvider {
             system,
             messages,
             filtered_tools,
-            &super::utils::ImageFormat::OpenAi,
+            &ImageFormat::OpenAi,
             true,
         )?;
         let mut log = RequestLog::start(&self.model, &payload)?;
diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs
index cf50dad128b0..87edb4b6435f 100644
--- a/crates/goose/src/providers/openai.rs
+++ b/crates/goose/src/providers/openai.rs
@@ -1,5 +1,7 @@
 use super::api_client::{ApiClient, AuthMethod};
-use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{
+    ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
+};
 use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
 use super::errors::ProviderError;
 use super::formats::openai::{create_request, get_usage, response_to_message};
@@ -7,16 +9,17 @@ use super::formats::openai_responses::{
     create_responses_request, get_responses_usage, responses_api_to_message,
     responses_api_to_streaming_message, ResponsesApiResponse,
 };
-use super::retry::ProviderRetry;
-use super::utils::{
-    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
-    ImageFormat,
+use super::openai_compatible::{
+    handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
 };
+use super::retry::ProviderRetry;
+use super::utils::{get_model, ImageFormat};
 use crate::config::declarative_providers::DeclarativeProviderConfig;
 use crate::conversation::message::Message;
 use anyhow::Result;
 use async_stream::try_stream;
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 use futures::{StreamExt, TryStreamExt};
 use reqwest::StatusCode;
 use serde_json::Value;
@@ -31,6 +34,7 @@ use crate::providers::base::MessageStream;
 use crate::providers::utils::RequestLog;
 use rmcp::model::Tool;
 
+const OPEN_AI_PROVIDER_NAME: &str = "openai";
 pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
 pub const OPEN_AI_DEFAULT_FAST_MODEL: &str = "gpt-4o-mini";
 pub const OPEN_AI_KNOWN_MODELS: &[(&str, usize)] = &[
@@ -71,10 +75,13 @@ impl OpenAiProvider {
             .get_param("OPENAI_HOST")
             .unwrap_or_else(|_| "https://api.openai.com".to_string());
 
-        let api_key: Option<String> = config.get_secret("OPENAI_API_KEY").ok();
-        let custom_headers: Option<HashMap<String, String>> = config
-            .get_secret::<String>("OPENAI_CUSTOM_HEADERS")
-            .ok()
+        let secrets = config
+            .get_secrets("OPENAI_API_KEY", &["OPENAI_CUSTOM_HEADERS"])
+            .unwrap_or_default();
+        let api_key: Option<String> = secrets.get("OPENAI_API_KEY").cloned();
+        let custom_headers: Option<HashMap<String, String>> = secrets
+            .get("OPENAI_CUSTOM_HEADERS")
+            .cloned()
             .map(parse_custom_headers);
 
         let base_path: String = config
@@ -117,7 +124,7 @@ impl OpenAiProvider {
             model,
             custom_headers,
supports_streaming: true, - name: Self::metadata().name, + name: OPEN_AI_PROVIDER_NAME.to_string(), }) } @@ -131,7 +138,7 @@ impl OpenAiProvider { model, custom_headers: None, supports_streaming: true, - name: Self::metadata().name, + name: OPEN_AI_PROVIDER_NAME.to_string(), } } @@ -228,15 +235,16 @@ impl OpenAiProvider { } } -#[async_trait] -impl Provider for OpenAiProvider { +impl ProviderDef for OpenAiProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { let models = OPEN_AI_KNOWN_MODELS .iter() .map(|(name, limit)| ModelInfo::new(*name, *limit)) .collect(); ProviderMetadata::with_models( - "openai", + OPEN_AI_PROVIDER_NAME, "OpenAI", "GPT-4 and other OpenAI models, including OpenAI compatible ones", OPEN_AI_DEFAULT_MODEL, @@ -254,6 +262,13 @@ impl Provider for OpenAiProvider { ) } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { + Box::pin(Self::from_env(model)) + } +} + +#[async_trait] +impl Provider for OpenAiProvider { fn get_name(&self) -> &str { &self.name } diff --git a/crates/goose/src/providers/openai_compatible.rs b/crates/goose/src/providers/openai_compatible.rs new file mode 100644 index 000000000000..daa1f685d08b --- /dev/null +++ b/crates/goose/src/providers/openai_compatible.rs @@ -0,0 +1,289 @@ +use anyhow::Error; +use async_stream::try_stream; +use futures::TryStreamExt; +use reqwest::{Response, StatusCode}; +use serde_json::Value; +use tokio::pin; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tokio_util::io::StreamReader; + +use super::api_client::ApiClient; +use super::base::{MessageStream, Provider, ProviderUsage, Usage}; +use super::errors::ProviderError; +use super::retry::ProviderRetry; +use super::utils::{get_model, ImageFormat, RequestLog}; +use crate::conversation::message::Message; +use crate::model::ModelConfig; +use crate::providers::formats::openai::{ + create_request, get_usage, response_to_message, response_to_streaming_message, +}; +use rmcp::model::Tool; + +pub struct OpenAiCompatibleProvider { + name: String, + /// Client targeted at the base URL (e.g. 
`https://api.x.ai/v1`) + api_client: ApiClient, + model: ModelConfig, +} + +impl OpenAiCompatibleProvider { + pub fn new(name: String, api_client: ApiClient, model: ModelConfig) -> Self { + Self { + name, + api_client, + model, + } + } + + fn build_request( + &self, + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], + for_streaming: bool, + ) -> Result { + create_request( + model_config, + system, + messages, + tools, + &ImageFormat::OpenAi, + for_streaming, + ) + .map_err(|e| ProviderError::RequestFailed(format!("Failed to create request: {}", e))) + } +} + +#[async_trait::async_trait] +impl Provider for OpenAiCompatibleProvider { + fn get_name(&self) -> &str { + &self.name + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument( + skip(self, model_config, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete_with_model( + &self, + session_id: Option<&str>, + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let payload = self.build_request(model_config, system, messages, tools, false)?; + let mut log = RequestLog::start(model_config, &payload)?; + + let response = self + .with_retry(|| async { + let resp = self + .api_client + .response_post(session_id, "chat/completions", &payload) + .await?; + handle_response_openai_compat(resp).await + }) + .await?; + + let response_model = get_model(&response); + let message = response_to_message(&response) + .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); + log.write(&response, Some(&usage))?; + + Ok((message, ProviderUsage::new(response_model, usage))) + } + + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let response = self + .api_client + .response_get(None, "models") + .await + .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; + let json = handle_response_openai_compat(response).await?; + + if let Some(err_obj) = json.get("error") { + let msg = err_obj + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + return Err(ProviderError::Authentication(msg.to_string())); + } + + let data = json.get("data").and_then(|v| v.as_array()); + match data { + Some(arr) => { + let mut models: Vec = arr + .iter() + .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string)) + .collect(); + models.sort(); + Ok(Some(models)) + } + None => Ok(None), + } + } + + fn supports_streaming(&self) -> bool { + true + } + + async fn stream( + &self, + session_id: &str, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let payload = self.build_request(&self.model, system, messages, tools, true)?; + let mut log = RequestLog::start(&self.model, &payload)?; + + let response = self + .with_retry(|| async { + let resp = self + .api_client + .response_post(Some(session_id), "chat/completions", &payload) + .await?; + handle_status_openai_compat(resp).await + }) + .await + .inspect_err(|e| { + let _ = log.error(e); + })?; + + stream_openai_compat(response, log) + } +} + +fn check_context_length_exceeded(text: &str) -> bool { + let check_phrases = [ + "too long", + "context length", + "context_length_exceeded", + "reduce the length", + "token count", + "exceeds", + "exceed context 
limit", + "input length", + "max_tokens", + "decrease input length", + "context limit", + "maximum prompt length", + ]; + let text_lower = text.to_lowercase(); + check_phrases + .iter() + .any(|phrase| text_lower.contains(phrase)) +} + +pub fn map_http_error_to_provider_error( + status: StatusCode, + payload: Option, +) -> ProviderError { + let extract_message = || -> String { + payload + .as_ref() + .and_then(|p| { + p.get("error") + .and_then(|e| e.get("message")) + .or_else(|| p.get("message")) + .and_then(|m| m.as_str()) + .map(String::from) + }) + .unwrap_or_else(|| payload.as_ref().map(|p| p.to_string()).unwrap_or_default()) + }; + + let error = match status { + StatusCode::OK => unreachable!("Should not call this function with OK status"), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ProviderError::Authentication(format!( + "Authentication failed. Status: {}. Response: {}", + status, + extract_message() + )), + StatusCode::NOT_FOUND => { + ProviderError::RequestFailed(format!("Resource not found (404): {}", extract_message())) + } + StatusCode::PAYLOAD_TOO_LARGE => ProviderError::ContextLengthExceeded(extract_message()), + StatusCode::BAD_REQUEST => { + let payload_str = extract_message(); + if check_context_length_exceeded(&payload_str) { + ProviderError::ContextLengthExceeded(payload_str) + } else { + ProviderError::RequestFailed(format!("Bad request (400): {}", payload_str)) + } + } + StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded { + details: extract_message(), + retry_delay: None, + }, + _ if status.is_server_error() => { + ProviderError::ServerError(format!("Server error ({}): {}", status, extract_message())) + } + _ => ProviderError::RequestFailed(format!( + "Request failed with status {}: {}", + status, + extract_message() + )), + }; + + if !status.is_success() { + tracing::warn!( + "Provider request failed with status: {}. Payload: {:?}. Returning error: {:?}", + status, + payload, + error + ); + } + + error +} + +pub async fn handle_status_openai_compat(response: Response) -> Result { + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + let payload = serde_json::from_str::(&body).ok(); + return Err(map_http_error_to_provider_error(status, payload)); + } + Ok(response) +} + +pub async fn handle_response_openai_compat(response: Response) -> Result { + let response = handle_status_openai_compat(response).await?; + + response.json::().await.map_err(|e| { + ProviderError::RequestFailed(format!("Response body is not valid JSON: {}", e)) + }) +} + +pub fn stream_openai_compat( + response: Response, + mut log: RequestLog, +) -> Result { + let stream = response.bytes_stream().map_err(std::io::Error::other); + + Ok(Box::pin(try_stream! 
diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 9d7b3523b406..0bdcd198b267 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -1,22 +1,25 @@ use anyhow::Result; use async_trait::async_trait; +use futures::future::BoxFuture; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, }; use super::errors::ProviderError; -use super::retry::ProviderRetry; -use super::utils::{ - get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, - RequestLog, +use super::openai_compatible::{ + handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, }; +use super::retry::ProviderRetry; +use super::utils::{get_model, ImageFormat, RequestLog}; use crate::conversation::message::Message; - use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage}; use crate::providers::formats::openrouter as openrouter_format; use rmcp::model::Tool; +const OPENROUTER_PROVIDER_NAME: &str = "openrouter"; pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-sonnet-4"; pub const OPENROUTER_DEFAULT_FAST_MODEL: &str = "google/gemini-2.5-flash"; pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic"; @@ -65,7 +68,7 @@ impl OpenRouterProvider { api_client, model, supports_streaming: true, - name: Self::metadata().name, + name: OPENROUTER_PROVIDER_NAME.to_string(), }) } @@ -203,7 +206,7 @@ async fn create_request_based_on_model( system, messages, tools, - &super::utils::ImageFormat::OpenAi, + &ImageFormat::OpenAi, false, )?; @@ -228,11 +231,12 @@ async fn create_request_based_on_model( Ok(payload) } -#[async_trait] -impl Provider for OpenRouterProvider { +impl ProviderDef for OpenRouterProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "openrouter", + OPENROUTER_PROVIDER_NAME, "OpenRouter", "Router for many model providers", OPENROUTER_DEFAULT_MODEL, @@ -251,6 +255,13 @@ .with_unlisted_models() } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(Self::from_env(model)) + } +} + +#[async_trait] +impl Provider for OpenRouterProvider { fn get_name(&self) -> &str { &self.name } @@ -401,7 +412,7 @@ impl Provider for OpenRouterProvider { system, messages, tools, - &super::utils::ImageFormat::OpenAi, + &ImageFormat::OpenAi, true, )?;
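For orientation while reading the impl ProviderDef blocks in this diff: the trait itself lives in providers/base.rs, which these hunks only import. Reconstructed from the impl sites, its shape is presumably close to the following sketch (names and bounds are inferred, not verbatim from the change):

// Sketch only; inferred from usage in this diff.
pub trait ProviderDef {
    // The concrete provider this definition constructs: Self for most
    // providers, OpenAiCompatibleProvider for thin wrappers like xai below.
    type Provider: Provider + 'static;

    // Static description the registry can read before constructing anything.
    fn metadata() -> ProviderMetadata;

    // Boxed async constructor; every impl in this diff just wraps an
    // inherent from_env on the concrete type.
    fn from_env(model: ModelConfig) -> BoxFuture<'static, anyhow::Result<Self::Provider>>;
}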
diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index 4f2794d884ba..b5a25b0f6578 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -1,4 +1,4 @@ -use super::base::{ModelInfo, Provider, ProviderMetadata, ProviderType}; +use super::base::{ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderType}; use crate::config::DeclarativeProviderConfig; use crate::model::ModelConfig; use anyhow::Result; @@ -36,22 +36,20 @@ impl ProviderRegistry { } } - pub fn register<P, F>(&mut self, constructor: F, preferred: bool) + pub fn register<F>(&mut self, preferred: bool) where - P: Provider + 'static, - F: Fn(ModelConfig) -> BoxFuture<'static, Result<P>> + Send + Sync + 'static, + F: ProviderDef + 'static, { - let metadata = P::metadata(); + let metadata = F::metadata(); let name = metadata.name.clone(); self.entries.insert( name, ProviderEntry { metadata, - constructor: Arc::new(move |model| { - let fut = constructor(model); + constructor: Arc::new(|model| { Box::pin(async move { - let provider = fut.await?; + let provider = F::from_env(model).await?; Ok(Arc::new(provider) as Arc<dyn Provider>) }) }), @@ -70,8 +68,8 @@ provider_type: ProviderType, constructor: F, ) where - P: Provider + 'static, - F: Fn(ModelConfig) -> Result<P> + Send + Sync + 'static, + P: ProviderDef + 'static, + F: Fn(ModelConfig) -> Result<P::Provider> + Send + Sync + 'static, { let base_metadata = P::metadata(); let description = config
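The net effect at registration call sites is that the duplicated constructor closure disappears; the single type parameter now carries both metadata() and from_env(). An illustrative before/after (the real registration list is outside these hunks, and the exact turbofish form is an assumption from the signatures above):

// Before: constructor closure passed explicitly alongside the type.
// registry.register::<OpenRouterProvider, _>(
//     |model| Box::pin(OpenRouterProvider::from_env(model)),
//     true,
// );
// After: the ProviderDef impl supplies construction and metadata.
registry.register::<OpenRouterProvider>(true);
registry.register::<XaiProvider>(false); // constructs an OpenAiCompatibleProvider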
diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 0517836ef117..5981d5f2fc41 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -9,7 +9,7 @@ use aws_sdk_sagemakerruntime::Client as SageMakerClient; use rmcp::model::Tool; use serde_json::{json, Value}; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::RequestLog; @@ -18,8 +18,10 @@ use crate::session_context::SESSION_ID_HEADER; use crate::model::ModelConfig; use chrono::Utc; +use futures::future::BoxFuture; use rmcp::model::Role; +const SAGEMAKER_TGI_PROVIDER_NAME: &str = "sagemaker_tgi"; pub const SAGEMAKER_TGI_DOC_LINK: &str = "https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html"; @@ -82,7 +84,7 @@ impl SageMakerTgiProvider { sagemaker_client, endpoint_name, model, - name: Self::metadata().name, + name: SAGEMAKER_TGI_PROVIDER_NAME.to_string(), }) } @@ -268,11 +270,12 @@ impl SageMakerTgiProvider { } } -#[async_trait] -impl Provider for SageMakerTgiProvider { +impl ProviderDef for SageMakerTgiProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "sagemaker_tgi", + SAGEMAKER_TGI_PROVIDER_NAME, "Amazon SageMaker TGI", "Run Text Generation Inference models through Amazon SageMaker endpoints. Requires AWS credentials and a SageMaker endpoint URL.", SAGEMAKER_TGI_DEFAULT_MODEL, @@ -286,6 +289,13 @@ ) } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(Self::from_env(model)) + } +} + +#[async_trait] +impl Provider for SageMakerTgiProvider { fn get_name(&self) -> &str { &self.name }
diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 88db4cd213d3..f0fcb01e262f 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -4,17 +4,20 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; +use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::formats::snowflake::{create_request, get_usage, response_to_message}; +use super::openai_compatible::map_http_error_to_provider_error; use super::retry::ProviderRetry; -use super::utils::{get_model, map_http_error_to_provider_error, ImageFormat, RequestLog}; +use super::utils::{get_model, ImageFormat, RequestLog}; use crate::config::ConfigError; use crate::conversation::message::Message; use crate::model::ModelConfig; +use futures::future::BoxFuture; use rmcp::model::Tool; +const SNOWFLAKE_PROVIDER_NAME: &str = "snowflake"; pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-sonnet-4-5"; pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &[ // Claude 4.5 series @@ -103,7 +106,7 @@ impl SnowflakeProvider { api_client, model, image_format: ImageFormat::OpenAi, - name: Self::metadata().name, + name: SNOWFLAKE_PROVIDER_NAME.to_string(), }) } @@ -292,11 +295,12 @@ impl SnowflakeProvider { } } -#[async_trait] -impl Provider for SnowflakeProvider { +impl ProviderDef for SnowflakeProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "snowflake", + SNOWFLAKE_PROVIDER_NAME, "Snowflake", "Access the latest models using Snowflake Cortex services.", SNOWFLAKE_DEFAULT_MODEL, @@ -309,6 +313,13 @@ ) } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(Self::from_env(model)) + } +} + +#[async_trait] +impl Provider for SnowflakeProvider { fn get_name(&self) -> &str { &self.name }
diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index fa249ea0ec9b..942c0a2b5bd0 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; @@ -7,10 +7,11 @@ use std::fs; use std::path::Path; use std::sync::{Arc, Mutex}; -use super::base::{Provider, ProviderMetadata, ProviderUsage}; +use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use crate::conversation::message::Message; use crate::model::ModelConfig; +use futures::future::BoxFuture; use rmcp::model::Tool; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -40,12 +41,14 @@ pub struct TestProvider { } impl TestProvider { + const PROVIDER_NAME: &str = "test"; + pub fn new_recording(inner: Arc<dyn Provider>, file_path: impl Into<PathBuf>) -> Self { Self { inner: Some(inner), records: Arc::new(Mutex::new(HashMap::new())), file_path: file_path.into(), - name: Self::metadata().name, + name: Self::PROVIDER_NAME.to_string(), } } @@ -57,7 +60,7 @@ inner: None, records: Arc::new(Mutex::new(records)), file_path, - name: Self::metadata().name, + name: Self::PROVIDER_NAME.to_string(), }) } @@ -101,11 +104,12 @@ impl TestProvider { } } -#[async_trait] -impl Provider for TestProvider { +impl ProviderDef for TestProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "test", + Self::PROVIDER_NAME, "Test Provider", "Provider for testing that can record/replay interactions", "test-model", @@ -115,6 +119,13 @@ ) } + fn from_env(_model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(async { Err(anyhow!("TestProvider must be constructed explicitly")) }) + } +} + +#[async_trait] +impl Provider for TestProvider { fn get_name(&self) -> &str { &self.name } @@ -188,18 +199,6 @@ #[async_trait] impl Provider for MockProvider { - fn metadata() -> ProviderMetadata { - ProviderMetadata::new( - "mock", - "Mock Provider", - "Mock provider for testing", - "mock-model", - vec!["mock-model"], - "", - vec![], - ) - } - fn get_name(&self) -> &str { "mock-testprovider" }
diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index a79932d9a216..ba31a28cfdcd 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -1,21 +1,25 @@ use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, }; use super::errors::ProviderError; -use super::retry::ProviderRetry; -use super::utils::{ - get_model, handle_response_google_compat, 
handle_response_openai_compat, - handle_status_openai_compat, is_google_model, stream_openai_compat, RequestLog, +use super::openai_compatible::{ + handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, }; +use super::retry::ProviderRetry; +use super::utils::{get_model, handle_response_google_compat, is_google_model, RequestLog}; use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL; use crate::conversation::message::Message; use anyhow::Result; use async_trait::async_trait; +use futures::future::BoxFuture; use serde_json::Value; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use rmcp::model::Tool; +const TETRATE_PROVIDER_NAME: &str = "tetrate"; // Tetrate Agent Router Service can run many models, we suggest the default pub const TETRATE_KNOWN_MODELS: &[&str] = &[ "claude-opus-4-1", @@ -59,7 +63,7 @@ impl TetrateProvider { api_client, model, supports_streaming: true, - name: Self::metadata().name, + name: TETRATE_PROVIDER_NAME.to_string(), }) } @@ -126,11 +130,12 @@ impl TetrateProvider { } } -#[async_trait] -impl Provider for TetrateProvider { +impl ProviderDef for TetrateProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "tetrate", + TETRATE_PROVIDER_NAME, "Tetrate Agent Router Service", "Enterprise router for AI models", TETRATE_DEFAULT_MODEL, @@ -148,6 +153,13 @@ ) } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(Self::from_env(model)) + } +} + +#[async_trait] +impl Provider for TetrateProvider { fn get_name(&self) -> &str { &self.name }
diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index b3cd68a87005..2a8360c673a8 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -1,13 +1,10 @@ -use super::base::{MessageStream, Usage}; +use super::base::Usage; use super::errors::GoogleErrorCode; use crate::config::paths::Paths; use crate::model::ModelConfig; use crate::providers::errors::ProviderError; -use crate::providers::formats::openai::response_to_streaming_message; use anyhow::{anyhow, Result}; -use async_stream::try_stream; use base64::Engine; -use futures::TryStreamExt; use regex::Regex; use reqwest::{Response, StatusCode}; use rmcp::model::{AnnotateAble, ImageContent, RawImageContent}; @@ -15,15 +12,10 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::fmt::Display; use std::fs::File; -use std::io; use std::io::{BufWriter, Read, Write}; use std::path::{Path, PathBuf}; use std::sync::OnceLock; use std::time::Duration; -use tokio::pin; -use tokio_stream::StreamExt; -use tokio_util::codec::{FramedRead, LinesCodec}; -use tokio_util::io::StreamReader; use uuid::Uuid; #[derive(Debug, Copy, Clone, Serialize, Deserialize)] @@ -77,27 +69,6 @@ pub fn filter_extensions_from_system_prompt(system: &str) -> String { } } -fn check_context_length_exceeded(text: &str) -> bool { - let check_phrases = [ - "too long", - "context length", - "context_length_exceeded", - "reduce the length", - "token count", - "exceeds", - "exceed context limit", - "input length", - "max_tokens", - "decrease input length", - "context limit", - "maximum prompt length", - ]; - let text_lower = text.to_lowercase(); - check_phrases - .iter() - .any(|phrase| text_lower.contains(phrase)) -} - fn format_server_error_message(status_code: StatusCode, payload: Option<&Value>) -> String { match payload { 
Some(Value::Null) | None => format!( @@ -108,109 +79,6 @@ } } -pub fn map_http_error_to_provider_error( - status: StatusCode, - payload: Option<Value>, -) -> ProviderError { - let extract_message = || -> String { - payload - .as_ref() - .and_then(|p| { - p.get("error") - .and_then(|e| e.get("message")) - .or_else(|| p.get("message")) - .and_then(|m| m.as_str()) - .map(String::from) - }) - .unwrap_or_else(|| payload.as_ref().map(|p| p.to_string()).unwrap_or_default()) - }; - - let error = match status { - StatusCode::OK => unreachable!("Should not call this function with OK status"), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ProviderError::Authentication(format!( - "Authentication failed. Status: {}. Response: {}", - status, - extract_message() - )), - StatusCode::NOT_FOUND => { - ProviderError::RequestFailed(format!("Resource not found (404): {}", extract_message())) - } - StatusCode::PAYLOAD_TOO_LARGE => ProviderError::ContextLengthExceeded(extract_message()), - StatusCode::BAD_REQUEST => { - let payload_str = extract_message(); - if check_context_length_exceeded(&payload_str) { - ProviderError::ContextLengthExceeded(payload_str) - } else { - ProviderError::RequestFailed(format!("Bad request (400): {}", payload_str)) - } - } - StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded { - details: extract_message(), - retry_delay: None, - }, - _ if status.is_server_error() => { - ProviderError::ServerError(format!("Server error ({}): {}", status, extract_message())) - } - _ => ProviderError::RequestFailed(format!( - "Request failed with status {}: {}", - status, - extract_message() - )), - }; - - if !status.is_success() { - tracing::warn!( - "Provider request failed with status: {}. Payload: {:?}. Returning error: {:?}", - status, - payload, - error - ); - } - - error -} - -pub async fn handle_status_openai_compat(response: Response) -> Result<Response, ProviderError> { - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - let payload = serde_json::from_str::<Value>(&body).ok(); - return Err(map_http_error_to_provider_error(status, payload)); - } - Ok(response) -} - -pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> { - let response = handle_status_openai_compat(response).await?; - - response.json::<Value>().await.map_err(|e| { - ProviderError::RequestFailed(format!("Response body is not valid JSON: {}", e)) - }) -} - -pub fn stream_openai_compat( - response: Response, - mut log: RequestLog, -) -> Result<MessageStream, ProviderError> { - let stream = response.bytes_stream().map_err(io::Error::other); - - Ok(Box::pin(try_stream! 
{ - let stream_reader = StreamReader::new(stream); - let framed = FramedRead::new(stream_reader, LinesCodec::new()) - .map_err(anyhow::Error::from); - - let message_stream = response_to_streaming_message(framed); - pin!(message_stream); - while let Some(message) = message_stream.next().await { - let (message, usage) = message.map_err(|e| - ProviderError::RequestFailed(format!("Stream decode error: {}", e)) - )?; - log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?; - yield (message, usage); - } - })) -} - pub fn is_google_model(payload: &Value) -> bool { payload .get("model")
diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 935590d6da1c..e42b1cb00715 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -5,14 +5,15 @@ use serde::Serialize; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; +use super::openai_compatible::map_http_error_to_provider_error; use super::retry::ProviderRetry; -use super::utils::map_http_error_to_provider_error; use crate::conversation::message::{Message, MessageContent}; use crate::mcp_utils::ToolResult; use crate::model::ModelConfig; +use futures::future::BoxFuture; use rmcp::model::{object, CallToolRequestParams, Role, Tool}; // ---------- Capability Flags ---------- @@ -58,6 +59,7 @@ fn strip_flags(model: &str) -> &str { } // ---------- END Helpers ---------- +const VENICE_PROVIDER_NAME: &str = "venice"; pub const VENICE_DOC_URL: &str = "https://docs.venice.ai/"; pub const VENICE_DEFAULT_MODEL: &str = "llama-3.3-70b"; pub const VENICE_DEFAULT_HOST: &str = "https://api.venice.ai"; @@ -107,7 +109,7 @@ impl VeniceProvider { base_path, models_path, model, - name: Self::metadata().name, + name: VENICE_PROVIDER_NAME.to_string(), }; Ok(instance) } @@ -192,11 +194,12 @@ impl VeniceProvider { } } -#[async_trait] -impl Provider for VeniceProvider { +impl ProviderDef for VeniceProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "venice", + VENICE_PROVIDER_NAME, "Venice.ai", "Venice.ai models (Llama, DeepSeek, Mistral) with function calling", VENICE_DEFAULT_MODEL, @@ -221,6 +224,13 @@ ) } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(Self::from_env(model)) + } +} + +#[async_trait] +impl Provider for VeniceProvider { fn get_name(&self) -> &str { &self.name }
diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 908bc73fbe4b..c3cfcc815205 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -1,20 +1,11 @@ use super::api_client::{ApiClient, AuthMethod}; -use super::errors::ProviderError; -use super::retry::ProviderRetry; -use super::utils::{ - get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, - RequestLog, -}; -use crate::conversation::message::Message; +use super::base::{ConfigKey, ProviderDef, ProviderMetadata}; +use super::openai_compatible::OpenAiCompatibleProvider; use crate::model::ModelConfig; -use crate::providers::base::{ - ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage, -}; -use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use anyhow::Result; -use 
async_trait::async_trait; -use rmcp::model::Tool; -use serde_json::Value; +use futures::future::BoxFuture; + +const XAI_PROVIDER_NAME: &str = "xai"; pub const XAI_API_HOST: &str = "https://api.x.ai/v1"; pub const XAI_DEFAULT_MODEL: &str = "grok-code-fast-1"; pub const XAI_KNOWN_MODELS: &[&str] = &[ @@ -40,50 +31,14 @@ pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview"; -#[derive(serde::Serialize)] -pub struct XaiProvider { - #[serde(skip)] - api_client: ApiClient, - model: ModelConfig, - supports_streaming: bool, - #[serde(skip)] - name: String, -} - -impl XaiProvider { - pub async fn from_env(model: ModelConfig) -> Result<Self> { - let config = crate::config::Config::global(); - let api_key: String = config.get_secret("XAI_API_KEY")?; - let host: String = config - .get_param("XAI_HOST") - .unwrap_or_else(|_| XAI_API_HOST.to_string()); - - let auth = AuthMethod::BearerToken(api_key); - let api_client = ApiClient::new(host, auth)?; - - Ok(Self { - api_client, - model, - supports_streaming: true, - name: Self::metadata().name, - }) - } - - async fn post(&self, session_id: Option<&str>, payload: Value) -> Result<Value, ProviderError> { - let response = self - .api_client - .response_post(session_id, "chat/completions", &payload) - .await?; - handle_response_openai_compat(response).await - } -} +pub struct XaiProvider; -#[async_trait] -impl Provider for XaiProvider { +impl ProviderDef for XaiProvider { + type Provider = OpenAiCompatibleProvider; fn metadata() -> ProviderMetadata { ProviderMetadata::new( - "xai", + XAI_PROVIDER_NAME, "xAI", "Grok models from xAI, including reasoning and multimodal capabilities", XAI_DEFAULT_MODEL, @@ -96,114 +51,21 @@ ) } - fn get_name(&self) -> &str { - &self.name - } - - fn get_model_config(&self) -> ModelConfig { - self.model.clone() - } - - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = create_request( - model_config, - system, - messages, - tools, - &super::utils::ImageFormat::OpenAi, - false, - )?; - - let mut log = RequestLog::start(&self.model, &payload)?; - let response = self - .with_retry(|| self.post(session_id, payload.clone())) - .await?; - - let message = response_to_message(&response)?; - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - let response_model = get_model(&response); - log.write(&response, Some(&usage))?; - Ok((message, ProviderUsage::new(response_model, usage))) - } - - fn supports_streaming(&self) -> bool { - self.supports_streaming - } - - async fn stream( - &self, - session_id: &str, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<MessageStream, ProviderError> { - let payload = create_request( - &self.model, - system, - messages, - tools, - &super::utils::ImageFormat::OpenAi, - true, - )?; - let mut log = RequestLog::start(&self.model, &payload)?; - - let response = self - .with_retry(|| async { - let resp = self - .api_client - .response_post(Some(session_id), "chat/completions", &payload) - .await?; - handle_status_openai_compat(resp).await - }) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; - - stream_openai_compat(response, log) -
} - - async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> { - let response = self.api_client.response_get(None, "models").await?; - let json = handle_response_openai_compat(response).await?; - - if let Some(err_obj) = json.get("error") { - let msg = err_obj - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or("unknown error"); - return Err(ProviderError::Authentication(msg.to_string())); - } - - let data = json.get("data").and_then(|v| v.as_array()); - match data { - Some(arr) => { - let mut models: Vec<String> = arr - .iter() - .filter_map(|m| { - m.get("id") - .and_then(|id| id.as_str()) - .map(|s| s.to_string()) - }) - .collect(); - models.sort(); - Ok(Some(models)) - } - None => Ok(None), - } + fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> { + Box::pin(async move { + let config = crate::config::Config::global(); + let api_key: String = config.get_secret("XAI_API_KEY")?; + let host: String = config + .get_param("XAI_HOST") + .unwrap_or_else(|_| XAI_API_HOST.to_string()); + + let api_client = ApiClient::new(host, AuthMethod::BearerToken(api_key))?; + + Ok(OpenAiCompatibleProvider::new( + XAI_PROVIDER_NAME.to_string(), + api_client, + model, + )) + }) } }
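With XaiProvider reduced to a definition, construction now yields the shared OpenAI-compatible implementation. A rough usage sketch (the set_var line is illustrative only and assumes the config layer falls back to environment variables for the XAI_API_KEY secret named above):

// Inside an async fn returning anyhow::Result<()>.
std::env::set_var("XAI_API_KEY", "sk-illustrative"); // assumption, see lead-in
let model = ModelConfig::new("grok-code-fast-1")?;
let provider = XaiProvider::from_env(model).await?; // an OpenAiCompatibleProvider
assert_eq!(provider.get_name(), "xai");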
diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 4f353e6c2fff..ed37fa02711d 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -336,7 +336,9 @@ mod tests { use goose::agents::SessionConfig; use goose::conversation::message::{Message, MessageContent}; use goose::model::ModelConfig; - use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; + use goose::providers::base::{ Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, }; use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; use rmcp::model::{CallToolRequestParams, Tool}; @@ -351,6 +353,29 @@ } } + impl ProviderDef for MockToolProvider { + type Provider = Self; + + fn metadata() -> ProviderMetadata { + ProviderMetadata { + name: "mock".to_string(), + display_name: "Mock Provider".to_string(), + description: "Mock provider for testing".to_string(), + default_model: "mock-model".to_string(), + known_models: vec![], + model_doc_link: "".to_string(), + config_keys: vec![], + allows_unlisted_models: false, + } + } + + fn from_env( + _model: ModelConfig, + ) -> futures::future::BoxFuture<'static, anyhow::Result<Self::Provider>> { + Box::pin(async { Ok(Self::new()) }) + } + } + #[async_trait] impl Provider for MockToolProvider { async fn complete( @@ -394,19 +419,6 @@ ModelConfig::new("mock-model").unwrap() } - fn metadata() -> ProviderMetadata { - ProviderMetadata { - name: "mock".to_string(), - display_name: "Mock Provider".to_string(), - description: "Mock provider for testing".to_string(), - default_model: "mock-model".to_string(), - known_models: vec![], - model_doc_link: "".to_string(), - config_keys: vec![], - allows_unlisted_models: false, - } - } - fn get_name(&self) -> &str { "mock-test" }
diff --git a/crates/goose/tests/compaction.rs b/crates/goose/tests/compaction.rs index 59d7f923a40d..b922912432a5 100644 --- a/crates/goose/tests/compaction.rs +++ b/crates/goose/tests/compaction.rs @@ -5,7 +5,7 @@ use goose::agents::{Agent, AgentEvent, SessionConfig}; use goose::conversation::message::{Message, MessageContent}; use goose::conversation::Conversation; use goose::model::ModelConfig; -use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; +use goose::providers::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; use goose::session::Session; @@ -170,6 +170,14 @@ impl Provider for MockCompactionProvider { ModelConfig::new("mock-model").unwrap() } + fn get_name(&self) -> &str { + "mock-compaction" + } +} + +impl ProviderDef for MockCompactionProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata { name: "mock".to_string(), @@ -183,8 +191,8 @@ } } - fn get_name(&self) -> &str { - "mock-compaction" + fn from_env(_model: ModelConfig) -> futures::future::BoxFuture<'static, anyhow::Result<Self::Provider>> { + Box::pin(async { Ok(Self::new()) }) } }
diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 635c8d66023d..22a279305e8a 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -18,7 +18,7 @@ use test_case::test_case; use async_trait::async_trait; use goose::conversation::message::Message; -use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; +use goose::providers::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; use once_cell::sync::Lazy; use std::process::Command; @@ -47,12 +47,20 @@ impl MockProvider { } } -#[async_trait] -impl Provider for MockProvider { +impl ProviderDef for MockProvider { + type Provider = Self; + fn metadata() -> ProviderMetadata { ProviderMetadata::empty() } + fn from_env(model: ModelConfig) -> futures::future::BoxFuture<'static, anyhow::Result<Self::Provider>> { + Box::pin(async move { Ok(Self::new(model)) }) + } +} + +#[async_trait] +impl Provider for MockProvider { fn get_name(&self) -> &str { "mock" }