Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clippy-baselines/too_many_lines.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ crates/goose/src/agents/agent.rs::create_recipe
crates/goose/src/agents/agent.rs::dispatch_tool_call
crates/goose/src/agents/agent.rs::reply
crates/goose/src/agents/agent.rs::reply_internal
crates/goose/src/providers/canonical/build_canonical_models.rs::build_canonical_models
crates/goose/src/providers/formats/anthropic.rs::format_messages
crates/goose/src/providers/formats/anthropic.rs::response_to_streaming_message
crates/goose/src/providers/formats/databricks.rs::format_messages
Expand Down
3 changes: 3 additions & 0 deletions crates/goose-acp/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#![recursion_limit = "256"]
#![allow(unused_attributes)]

use assert_json_diff::{assert_json_matches_no_panic, CompareMode, Config};
use async_trait::async_trait;
use fs_err as fs;
Expand Down
4 changes: 0 additions & 4 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,6 @@ mod tests {

#[async_trait]
impl Provider for MockProvider {
fn metadata() -> crate::providers::base::ProviderMetadata {
crate::providers::base::ProviderMetadata::empty()
}

fn get_name(&self) -> &str {
"mock"
}
Expand Down
9 changes: 1 addition & 8 deletions crates/goose/src/context_mgmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,7 @@ mod tests {
use super::*;
use crate::{
model::ModelConfig,
providers::{
base::{ProviderMetadata, Usage},
errors::ProviderError,
},
providers::{base::Usage, errors::ProviderError},
};
use async_trait::async_trait;
use rmcp::model::{AnnotateAble, CallToolRequestParams, RawContent, Tool};
Expand Down Expand Up @@ -577,10 +574,6 @@ mod tests {

#[async_trait]
impl Provider for MockProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new("mock", "", "", "", vec![""], "", vec![])
}

fn get_name(&self) -> &str {
"mock"
}
Expand Down
26 changes: 20 additions & 6 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,25 @@ use tokio::pin;
use tokio_util::io::StreamReader;

use super::api_client::{ApiClient, ApiResponse, AuthMethod};
use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::base::{
ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
};
use super::errors::ProviderError;
use super::formats::anthropic::{
create_request, get_usage, response_to_message, response_to_streaming_message,
};
use super::utils::{get_model, handle_status_openai_compat, map_http_error_to_provider_error};
use super::openai_compatible::handle_status_openai_compat;
use super::openai_compatible::map_http_error_to_provider_error;
use super::utils::get_model;
use crate::config::declarative_providers::DeclarativeProviderConfig;
use crate::conversation::message::Message;
use crate::model::ModelConfig;
use crate::providers::retry::ProviderRetry;
use crate::providers::utils::RequestLog;
use futures::future::BoxFuture;
use rmcp::model::Tool;

const ANTHROPIC_PROVIDER_NAME: &str = "anthropic";
pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-5";
const ANTHROPIC_DEFAULT_FAST_MODEL: &str = "claude-haiku-4-5";
const ANTHROPIC_KNOWN_MODELS: &[&str] = &[
Expand Down Expand Up @@ -73,7 +79,7 @@ impl AnthropicProvider {
api_client,
model,
supports_streaming: true,
name: Self::metadata().name,
name: ANTHROPIC_PROVIDER_NAME.to_string(),
})
}

Expand Down Expand Up @@ -171,16 +177,17 @@ impl AnthropicProvider {
}
}

#[async_trait]
impl Provider for AnthropicProvider {
impl ProviderDef for AnthropicProvider {
type Provider = Self;

fn metadata() -> ProviderMetadata {
let models: Vec<ModelInfo> = ANTHROPIC_KNOWN_MODELS
.iter()
.map(|&model_name| ModelInfo::new(model_name, 200_000))
.collect();

ProviderMetadata::with_models(
"anthropic",
ANTHROPIC_PROVIDER_NAME,
"Anthropic",
"Claude and other models from Anthropic",
ANTHROPIC_DEFAULT_MODEL,
Expand All @@ -198,6 +205,13 @@ impl Provider for AnthropicProvider {
)
}

fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
}
}

#[async_trait]
impl Provider for AnthropicProvider {
fn get_name(&self) -> &str {
&self.name
}
Expand Down
17 changes: 15 additions & 2 deletions crates/goose/src/providers/api_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct ApiClient {
host: String,
auth: AuthMethod,
default_headers: HeaderMap,
default_query: Vec<(String, String)>,
timeout: Duration,
tls_config: Option<TlsConfig>,
}
Expand Down Expand Up @@ -222,6 +223,7 @@ impl ApiClient {
host,
auth,
default_headers: HeaderMap::new(),
default_query: Vec::new(),
timeout,
tls_config,
})
Expand Down Expand Up @@ -267,6 +269,11 @@ impl ApiClient {
Ok(self)
}

pub fn with_query(mut self, params: Vec<(String, String)>) -> Self {
self.default_query = params;
self
}

pub fn with_header(mut self, key: &str, value: &str) -> Result<Self> {
let header_name = HeaderName::from_bytes(key.as_bytes())?;
let header_value = HeaderValue::from_str(value)?;
Expand Down Expand Up @@ -325,9 +332,15 @@ impl ApiClient {
base_url.set_path(&format!("{}/", base_path));
}

base_url
let mut url = base_url
.join(path)
.map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e))
.map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e))?;

for (key, value) in &self.default_query {
url.query_pairs_mut().append_pair(key, value);
}

Ok(url)
}

async fn get_oauth_token(&self, config: &OAuthConfig) -> Result<String> {
Expand Down
167 changes: 41 additions & 126 deletions crates/goose/src/providers/azure.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,21 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::Serialize;
use serde_json::Value;

use super::api_client::{ApiClient, AuthMethod, AuthProvider};
use super::azureauth::{AuthError, AzureAuth};
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
use super::retry::ProviderRetry;
use super::utils::{get_model, handle_response_openai_compat, ImageFormat};
use crate::conversation::message::Message;
use super::base::{ConfigKey, ProviderDef, ProviderMetadata};
use super::openai_compatible::OpenAiCompatibleProvider;
use crate::model::ModelConfig;
use crate::providers::utils::RequestLog;
use rmcp::model::Tool;
use futures::future::BoxFuture;

const AZURE_PROVIDER_NAME: &str = "azure_openai";
pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
pub const AZURE_DOC_URL: &str =
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];

#[derive(Debug)]
pub struct AzureProvider {
api_client: ApiClient,
deployment_name: String,
api_version: String,
model: ModelConfig,
name: String,
}

impl Serialize for AzureProvider {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("AzureProvider", 2)?;
state.serialize_field("deployment_name", &self.deployment_name)?;
state.serialize_field("api_version", &self.api_version)?;
state.end()
}
}
pub struct AzureProvider;

// Custom auth provider that wraps AzureAuth
struct AzureAuthProvider {
Expand Down Expand Up @@ -69,60 +43,12 @@ impl AuthProvider for AzureAuthProvider {
}
}

impl AzureProvider {
pub async fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
let api_version: String = config
.get_param("AZURE_OPENAI_API_VERSION")
.unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());

let api_key = config
.get_secret("AZURE_OPENAI_API_KEY")
.ok()
.filter(|key: &String| !key.is_empty());
let auth = AzureAuth::new(api_key).map_err(|e| match e {
AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg),
AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg),
})?;

let auth_provider = AzureAuthProvider { auth };
let api_client = ApiClient::new(endpoint, AuthMethod::Custom(Box::new(auth_provider)))?;

Ok(Self {
api_client,
deployment_name,
api_version,
model,
name: Self::metadata().name,
})
}

async fn post(
&self,
session_id: Option<&str>,
payload: &Value,
) -> Result<Value, ProviderError> {
// Build the path for Azure OpenAI
let path = format!(
"openai/deployments/{}/chat/completions?api-version={}",
self.deployment_name, self.api_version
);
impl ProviderDef for AzureProvider {
type Provider = OpenAiCompatibleProvider;

let response = self
.api_client
.response_post(session_id, &path, payload)
.await?;
handle_response_openai_compat(response).await
}
}

#[async_trait]
impl Provider for AzureProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"azure_openai",
AZURE_PROVIDER_NAME,
"Azure OpenAI",
"Models through Azure OpenAI Service (uses Azure credential chain by default)",
"gpt-4o",
Expand All @@ -137,49 +63,38 @@ impl Provider for AzureProvider {
)
}

fn get_name(&self) -> &str {
&self.name
}

fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}

#[tracing::instrument(
skip(self, model_config, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete_with_model(
&self,
session_id: Option<&str>,
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(
model_config,
system,
messages,
tools,
&ImageFormat::OpenAi,
false,
)?;
let response = self
.with_retry(|| async {
let payload_clone = payload.clone();
self.post(session_id, &payload_clone).await
})
.await?;

let message = response_to_message(&response)?;
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
tracing::debug!("Failed to get usage data");
Usage::default()
});
let response_model = get_model(&response);
let mut log = RequestLog::start(model_config, &payload)?;
log.write(&response, Some(&usage))?;
Ok((message, ProviderUsage::new(response_model, usage)))
fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(async move {
let config = crate::config::Config::global();
let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
let api_version: String = config
.get_param("AZURE_OPENAI_API_VERSION")
.unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());

let api_key = config
.get_secret("AZURE_OPENAI_API_KEY")
.ok()
.filter(|key: &String| !key.is_empty());
let auth = AzureAuth::new(api_key).map_err(|e| match e {
AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg),
AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg),
})?;

let auth_provider = AzureAuthProvider { auth };
let host = format!(
"{}/openai/deployments/{}",
endpoint.trim_end_matches('/'),
deployment_name
);
let api_client = ApiClient::new(host, AuthMethod::Custom(Box::new(auth_provider)))?
.with_query(vec![("api-version".to_string(), api_version)]);

Ok(OpenAiCompatibleProvider::new(
AZURE_PROVIDER_NAME.to_string(),
api_client,
model,
))
})
}
}
Loading
Loading