diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs
index 1308513ccf40..99298f433f43 100644
--- a/crates/goose-cli/src/commands/configure.rs
+++ b/crates/goose-cli/src/commands/configure.rs
@@ -5,8 +5,13 @@ use cliclack::spinner;
 use console::style;
 use goose::key_manager::{get_keyring_secret, save_to_keyring, KeyRetrievalStrategy};
 use goose::message::Message;
+use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL;
+use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL;
 use goose::providers::factory;
+use goose::providers::google::GOOGLE_DEFAULT_MODEL;
+use goose::providers::groq::GROQ_DEFAULT_MODEL;
 use goose::providers::ollama::OLLAMA_MODEL;
+use goose::providers::openai::OPEN_AI_DEFAULT_MODEL;
 use std::error::Error;
 
 pub async fn handle_configure(
@@ -48,6 +53,7 @@ pub async fn handle_configure(
                 ("ollama", "Ollama", "Local open source models"),
                 ("anthropic", "Anthropic", "Claude models"),
                 ("google", "Google Gemini", "Gemini models"),
+                ("groq", "Groq", "AI models"),
             ])
             .interact()?
             .to_string()
@@ -154,11 +160,12 @@ pub async fn handle_configure(
 
 pub fn get_recommended_model(provider_name: &str) -> &str {
     match provider_name {
-        "openai" => "gpt-4o",
-        "databricks" => "claude-3-5-sonnet-2",
+        "openai" => OPEN_AI_DEFAULT_MODEL,
+        "databricks" => DATABRICKS_DEFAULT_MODEL,
         "ollama" => OLLAMA_MODEL,
-        "anthropic" => "claude-3-5-sonnet-2",
-        "google" => "gemini-1.5-flash",
+        "anthropic" => ANTHROPIC_DEFAULT_MODEL,
+        "google" => GOOGLE_DEFAULT_MODEL,
+        "groq" => GROQ_DEFAULT_MODEL,
         _ => panic!("Invalid provider name"),
     }
 }
@@ -170,6 +177,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> {
         "ollama" => vec!["OLLAMA_HOST"],
         "anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint
         "google" => vec!["GOOGLE_API_KEY"],
+        "groq" => vec!["GROQ_API_KEY"],
         _ => panic!("Invalid provider name"),
     }
 }
diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs
index 6e03f6b387cc..d62dfc5b868e 100644
--- a/crates/goose-cli/src/profile.rs
+++ b/crates/goose-cli/src/profile.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy};
 use goose::providers::configs::{
     AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig,
-    ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
+    GroqProviderConfig, ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
 };
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
@@ -130,7 +130,17 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon
                 .expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");
 
             ProviderConfig::Google(GoogleProviderConfig {
-                host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint
+                host: "https://generativelanguage.googleapis.com".to_string(),
+                api_key,
+                model: model_config,
+            })
+        }
+        "groq" => {
+            let api_key = get_keyring_secret("GROQ_API_KEY", KeyRetrievalStrategy::Both)
+                .expect("GROQ_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");
+
+            ProviderConfig::Groq(GroqProviderConfig {
+                host: "https://api.groq.com".to_string(),
                 api_key,
                 model: model_config,
             })
diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs
index de47633013a1..c6435db591f2 100644
--- a/crates/goose-server/src/configuration.rs
+++ b/crates/goose-server/src/configuration.rs
@@ -1,13 +1,14 @@
 use crate::error::{to_env_var, ConfigError};
 use config::{Config, Environment};
-use goose::providers::configs::GoogleProviderConfig;
+use goose::providers::configs::{GoogleProviderConfig, GroqProviderConfig};
+use goose::providers::openai::OPEN_AI_DEFAULT_MODEL;
 use goose::providers::{
     configs::{
         DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig,
         OpenAiProviderConfig, ProviderConfig,
     },
     factory::ProviderType,
-    google, ollama,
+    google, groq, ollama,
     utils::ImageFormat,
 };
 use serde::Deserialize;
@@ -88,6 +89,17 @@ pub enum ProviderSettings {
         #[serde(default)]
         max_tokens: Option<i32>,
     },
+    Groq {
+        #[serde(default = "default_groq_host")]
+        host: String,
+        api_key: String,
+        #[serde(default = "default_groq_model")]
+        model: String,
+        #[serde(default)]
+        temperature: Option<f32>,
+        #[serde(default)]
+        max_tokens: Option<i32>,
+    },
 }
 
 impl ProviderSettings {
@@ -99,6 +111,7 @@ impl ProviderSettings {
             ProviderSettings::Databricks { .. } => ProviderType::Databricks,
             ProviderSettings::Ollama { .. } => ProviderType::Ollama,
             ProviderSettings::Google { .. } => ProviderType::Google,
+            ProviderSettings::Groq { .. } => ProviderType::Groq,
         }
     }
 
@@ -168,6 +181,19 @@ impl ProviderSettings {
                     .with_temperature(temperature)
                     .with_max_tokens(max_tokens),
             }),
+            ProviderSettings::Groq {
+                host,
+                api_key,
+                model,
+                temperature,
+                max_tokens,
+            } => ProviderConfig::Groq(GroqProviderConfig {
+                host,
+                api_key,
+                model: ModelConfig::new(model)
+                    .with_temperature(temperature)
+                    .with_max_tokens(max_tokens),
+            }),
         }
     }
 }
@@ -240,7 +266,7 @@ fn default_port() -> u16 {
 }
 
 fn default_model() -> String {
-    "gpt-4o".to_string()
+    OPEN_AI_DEFAULT_MODEL.to_string()
 }
 
 fn default_openai_host() -> String {
@@ -267,6 +293,14 @@ fn default_google_model() -> String {
     google::GOOGLE_DEFAULT_MODEL.to_string()
 }
 
+fn default_groq_host() -> String {
+    groq::GROQ_API_HOST.to_string()
+}
+
+fn default_groq_model() -> String {
+    groq::GROQ_DEFAULT_MODEL.to_string()
+}
+
 fn default_image_format() -> ImageFormat {
     ImageFormat::Anthropic
 }
diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs
index 446c538dcee4..8c07f82547bd 100644
--- a/crates/goose-server/src/state.rs
+++ b/crates/goose-server/src/state.rs
@@ -1,4 +1,5 @@
 use anyhow::Result;
+use goose::providers::configs::GroqProviderConfig;
 use goose::{
     agent::Agent,
     developer::DeveloperSystem,
@@ -71,6 +72,11 @@ impl Clone for AppState {
                         model: config.model.clone(),
                     })
                 }
+                ProviderConfig::Groq(config) => ProviderConfig::Groq(GroqProviderConfig {
+                    host: config.host.clone(),
+                    api_key: config.api_key.clone(),
+                    model: config.model.clone(),
+                }),
             },
             agent: self.agent.clone(),
             secret_key: self.secret_key.clone(),
diff --git a/crates/goose/build.rs b/crates/goose/build.rs
index 421ad3df5821..ccfb369848bb 100644
--- a/crates/goose/build.rs
+++ b/crates/goose/build.rs
@@ -8,6 +8,7 @@ const MODELS: &[&str] = &[
     "Xenova/gemma-2-tokenizer",
     "Xenova/gpt-4o",
     "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "Xenova/llama3-tokenizer",
 ];
 
 #[tokio::main]
diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs
index f2d7758aec67..6f2fb5b9152f 100644
--- a/crates/goose/src/providers.rs
+++ b/crates/goose/src/providers.rs
@@ -7,8 +7,12 @@ pub mod model_pricing;
 pub mod oauth;
 pub mod ollama;
 pub mod openai;
+pub mod openai_utils;
 pub mod utils;
 
 pub mod google;
+pub mod groq;
 #[cfg(test)]
 pub mod mock;
+#[cfg(test)]
+pub mod mock_server;
diff --git
a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index c769eef4f4e5..c23d77df7e7f 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -17,6 +17,8 @@ use mcp_core::content::Content; use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolCall}; +pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest"; + pub struct AnthropicProvider { client: Client, config: AnthropicProviderConfig, diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 67c49282dc5f..94f6d585d3eb 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -14,6 +14,7 @@ pub enum ProviderConfig { Ollama(OllamaProviderConfig), Anthropic(AnthropicProviderConfig), Google(GoogleProviderConfig), + Groq(GroqProviderConfig), } /// Configuration for model-specific settings and limits @@ -222,6 +223,19 @@ impl ProviderModelConfig for GoogleProviderConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GroqProviderConfig { + pub host: String, + pub api_key: String, + pub model: ModelConfig, +} + +impl ProviderModelConfig for GroqProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaProviderConfig { pub host: String, diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 460341d450e5..6a2670392ea1 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; -use reqwest::{Client, StatusCode}; +use reqwest::Client; use serde_json::{json, Value}; use std::time::Duration; @@ -8,13 +8,16 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig}; use super::model_pricing::{cost, model_pricing_for}; use super::oauth; -use super::utils::{ - check_bedrock_context_length_error, check_openai_context_length_error, get_model, - messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, -}; +use super::utils::{check_bedrock_context_length_error, get_model, handle_response}; use crate::message::Message; +use crate::providers::openai_utils::{ + check_openai_context_length_error, get_openai_usage, messages_to_openai_spec, + openai_response_to_message, tools_to_openai_spec, +}; use mcp_core::tool::Tool; +pub const DATABRICKS_DEFAULT_MODEL: &str = "claude-3-5-sonnet-2"; + pub struct DatabricksProvider { client: Client, config: DatabricksProviderConfig, @@ -46,30 +49,7 @@ impl DatabricksProvider { } fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + get_openai_usage(data) } async fn post(&self, payload: Value) -> Result { @@ -88,18 +68,7 @@ impl DatabricksProvider { .send() .await?; - match response.status() { - StatusCode::OK => 
Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() => { - // Implement retry logic here if needed - Err(anyhow!("Server error: {}", status)) - } - _ => { - let status = response.status(); - let err_text = response.text().await.unwrap_or_default(); - Err(anyhow!("Request failed: {}: {}", status, err_text)) - } - } + handle_response(payload, response).await? } } @@ -112,7 +81,7 @@ impl Provider for DatabricksProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { // Prepare messages and tools - let messages_spec = messages_to_openai_spec(messages, &self.config.image_format); + let messages_spec = messages_to_openai_spec(messages, &self.config.image_format, false); let tools_spec = if !tools.is_empty() { tools_to_openai_spec(tools)? } else { @@ -179,6 +148,9 @@ mod tests { use super::*; use crate::message::MessageContent; use crate::providers::configs::ModelConfig; + use crate::providers::mock_server::{ + create_mock_open_ai_response, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOTAL_TOKENS, + }; use wiremock::matchers::{body_json, header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -188,19 +160,7 @@ mod tests { let mock_server = MockServer::start().await; // Mock response for completion - let mock_response = json!({ - "choices": [{ - "message": { - "role": "assistant", - "content": "Hello!" - } - }], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 25, - "total_tokens": 35 - } - }); + let mock_response = create_mock_open_ai_response("my-databricks-model", "Hello!"); // Expected request body let system = "You are a helpful assistant."; @@ -244,9 +204,9 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(reply_usage.usage.input_tokens, Some(10)); - assert_eq!(reply_usage.usage.output_tokens, Some(25)); - assert_eq!(reply_usage.usage.total_tokens, Some(35)); + assert_eq!(reply_usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(reply_usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(reply_usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); Ok(()) } diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index f5a9c0931dfe..58ad7513bef2 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,7 +1,7 @@ use super::{ anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig, - databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider, - openai::OpenAiProvider, + databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider, + ollama::OllamaProvider, openai::OpenAiProvider, }; use anyhow::Result; use strum_macros::EnumIter; @@ -13,6 +13,7 @@ pub enum ProviderType { Ollama, Anthropic, Google, + Groq, } pub fn get_provider(config: ProviderConfig) -> Result> { @@ -26,5 +27,6 @@ pub fn get_provider(config: ProviderConfig) -> Result Ok(Box::new(GoogleProvider::new(google_config)?)), + ProviderConfig::Groq(groq_config) => Ok(Box::new(GroqProvider::new(groq_config)?)), } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index e5ab052de328..d96b681add87 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -3,12 +3,11 @@ use crate::message::{Message, MessageContent}; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; use crate::providers::utils::{ - 
is_valid_function_name, sanitize_function_name, unescape_json_values, + handle_response, is_valid_function_name, sanitize_function_name, unescape_json_values, }; -use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; -use reqwest::{Client, StatusCode}; +use reqwest::Client; use serde_json::{json, Map, Value}; use std::time::Duration; @@ -66,18 +65,7 @@ impl GoogleProvider { .send() .await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { - // Implement retry logic here if needed - Err(anyhow!("Server error: {}", status)) - } - _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), - } + handle_response(payload, response).await? } fn messages_to_google_spec(&self, messages: &[Message]) -> Vec { @@ -361,6 +349,13 @@ impl Provider for GoogleProvider { mod tests { use super::*; use crate::errors::AgentResult; + use crate::providers::mock_server::{ + create_mock_google_ai_response, create_mock_google_ai_response_with_tools, + create_test_tool, get_expected_function_call_arguments, setup_mock_server, + TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, + }; + use wiremock::MockServer; + fn set_up_provider() -> GoogleProvider { let provider_config = GoogleProviderConfig { host: "dummy_host".to_string(), @@ -617,4 +612,88 @@ mod tests { panic!("Expected valid tool request"); } } + + async fn _setup_mock_server( + model_name: &str, + response_body: Value, + ) -> (MockServer, GoogleProvider) { + let path_url = format!("/v1beta/models/{}:generateContent", model_name); + let mock_server = setup_mock_server(&path_url, response_body).await; + let config = GoogleProviderConfig { + host: mock_server.uri(), + api_key: "test_api_key".to_string(), + model: ModelConfig::new(GOOGLE_DEFAULT_MODEL.to_string()), + }; + + let provider = GoogleProvider::new(config).unwrap(); + (mock_server, provider) + } + + #[tokio::test] + async fn test_complete_basic() -> anyhow::Result<()> { + let model_name = "gemini-1.5-flash"; + // Mock response for normal completion + let response_body = + create_mock_google_ai_response(model_name, "Hello! How can I assist you today?"); + + let (_, provider) = _setup_mock_server(model_name, response_body).await; + + // Prepare input messages + let messages = vec![Message::user().with_text("Hello?")]; + + // Call the complete method + let (message, usage) = provider + .complete("You are a helpful assistant.", &messages, &[]) + .await?; + + // Assert the response + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello! 
How can I assist you today?"); + } else { + panic!("Expected Text content"); + } + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); + assert_eq!(usage.cost, None); + + Ok(()) + } + + #[tokio::test] + async fn test_complete_tool_request() -> anyhow::Result<()> { + let model_name = "gemini-1.5-flash"; + // Mock response for tool calling + let response_body = create_mock_google_ai_response_with_tools("gpt-4o"); + + let (_, provider) = _setup_mock_server(model_name, response_body).await; + + // Input messages + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + + // Call the complete method + let (message, usage) = provider + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) + .await?; + + // Assert the response + if let MessageContent::ToolRequest(tool_request) = &message.content[0] { + let tool_call = tool_request.tool_call.as_ref().unwrap(); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); + } else { + panic!("Expected ToolCall content"); + } + + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + + Ok(()) + } } diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs new file mode 100644 index 000000000000..b0d1be1dc2a7 --- /dev/null +++ b/crates/goose/src/providers/groq.rs @@ -0,0 +1,166 @@ +use crate::message::Message; +use crate::providers::base::{Provider, ProviderUsage, Usage}; +use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; +use crate::providers::openai_utils::{ + create_openai_request_payload, get_openai_usage, openai_response_to_message, +}; +use crate::providers::utils::{get_model, handle_response}; +use async_trait::async_trait; +use mcp_core::Tool; +use reqwest::Client; +use serde_json::Value; +use std::time::Duration; + +pub const GROQ_API_HOST: &str = "https://api.groq.com"; +pub const GROQ_DEFAULT_MODEL: &str = "llama-3.3-70b-versatile"; + +pub struct GroqProvider { + client: Client, + config: GroqProviderConfig, +} + +impl GroqProvider { + pub fn new(config: GroqProviderConfig) -> anyhow::Result { + let client = Client::builder() + .timeout(Duration::from_secs(600)) // 10 minutes timeout + .build()?; + + Ok(Self { client, config }) + } + + fn get_usage(data: &Value) -> anyhow::Result { + get_openai_usage(data) + } + + async fn post(&self, payload: Value) -> anyhow::Result { + let url = format!( + "{}/openai/v1/chat/completions", + self.config.host.trim_end_matches('/') + ); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.config.api_key)) + .json(&payload) + .send() + .await?; + handle_response(payload, response).await? 
+ } +} + +#[async_trait] +impl Provider for GroqProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> anyhow::Result<(Message, ProviderUsage)> { + let payload = + create_openai_request_payload(&self.config.model, system, messages, tools, true)?; + + let response = self.post(payload).await?; + + let message = openai_response_to_message(response.clone())?; + let usage = Self::get_usage(&response)?; + let model = get_model(&response); + + Ok((message, ProviderUsage::new(model, usage, None))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::MessageContent; + use crate::providers::mock_server::{ + create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, + get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, + TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, + }; + use wiremock::MockServer; + + async fn _setup_mock_server(response_body: Value) -> (MockServer, GroqProvider) { + let mock_server = setup_mock_server("/openai/v1/chat/completions", response_body).await; + let config = GroqProviderConfig { + host: mock_server.uri(), + api_key: "test_api_key".to_string(), + model: ModelConfig::new(GROQ_DEFAULT_MODEL.to_string()), + }; + + let provider = GroqProvider::new(config).unwrap(); + (mock_server, provider) + } + + #[tokio::test] + async fn test_complete_basic() -> anyhow::Result<()> { + let model_name = "gpt-4o"; + // Mock response for normal completion + let response_body = + create_mock_open_ai_response(model_name, "Hello! How can I assist you today?"); + + let (_, provider) = _setup_mock_server(response_body).await; + + // Prepare input messages + let messages = vec![Message::user().with_text("Hello?")]; + + // Call the complete method + let (message, usage) = provider + .complete("You are a helpful assistant.", &messages, &[]) + .await?; + + // Assert the response + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello! 
How can I assist you today?"); + } else { + panic!("Expected Text content"); + } + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); + assert_eq!(usage.cost, None); + + Ok(()) + } + + #[tokio::test] + async fn test_complete_tool_request() -> anyhow::Result<()> { + // Mock response for tool calling + let response_body = create_mock_open_ai_response_with_tools("gpt-4o"); + + let (_, provider) = _setup_mock_server(response_body).await; + + // Input messages + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + + // Call the complete method + let (message, usage) = provider + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) + .await?; + + // Assert the response + if let MessageContent::ToolRequest(tool_request) = &message.content[0] { + let tool_call = tool_request.tool_call.as_ref().unwrap(); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); + } else { + panic!("Expected ToolCall content"); + } + + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + + Ok(()) + } +} diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index 830c20601a82..270870b59dcd 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -1,16 +1,14 @@ +use super::base::ProviderUsage; +use crate::message::Message; +use crate::providers::base::{Provider, Usage}; +use crate::providers::configs::ModelConfig; use anyhow::Result; use async_trait::async_trait; +use mcp_core::tool::Tool; use rust_decimal_macros::dec; use std::sync::Arc; use std::sync::Mutex; -use crate::message::Message; -use crate::providers::base::{Provider, Usage}; -use crate::providers::configs::ModelConfig; -use mcp_core::tool::Tool; - -use super::base::ProviderUsage; - /// A mock provider that returns pre-configured responses for testing pub struct MockProvider { responses: Arc>>, diff --git a/crates/goose/src/providers/mock_server.rs b/crates/goose/src/providers/mock_server.rs new file mode 100644 index 000000000000..8712cb8635c7 --- /dev/null +++ b/crates/goose/src/providers/mock_server.rs @@ -0,0 +1,152 @@ +use mcp_core::Tool; +use serde_json::{json, Value}; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +pub const TEST_INPUT_TOKENS: i32 = 12; +pub const TEST_OUTPUT_TOKENS: i32 = 15; +pub const TEST_TOTAL_TOKENS: i32 = 27; +pub const TEST_TOOL_FUNCTION_NAME: &str = "get_weather"; +pub const TEST_TOOL_FUNCTION_ARGUMENTS: &str = "{\"location\":\"San Francisco, CA\"}"; + +pub async fn setup_mock_server(path_url: &str, response_body: Value) -> MockServer { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path(path_url)) + .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) + .mount(&mock_server) + .await; + mock_server +} + +pub async fn setup_mock_server_with_response_code( + path_url: &str, + response_code: u16, +) -> MockServer { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path(path_url)) + .respond_with(ResponseTemplate::new(response_code)) + .mount(&mock_server) + .await; + 
mock_server +} +pub fn create_mock_open_ai_response_with_tools(model_name: &str) -> Value { + json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": TEST_TOOL_FUNCTION_NAME, + "arguments": TEST_TOOL_FUNCTION_ARGUMENTS + } + }] + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": TEST_INPUT_TOKENS, + "completion_tokens": TEST_OUTPUT_TOKENS, + "total_tokens": TEST_TOTAL_TOKENS + }, + "model": model_name + }) +} + +pub fn create_mock_google_ai_response_with_tools(model_name: &str) -> Value { + json!({ + "candidates": [{ + "content": { + "parts": [{ + "functionCall": { + "name": TEST_TOOL_FUNCTION_NAME, + "args":{ + "location": "San Francisco, CA" + } + + } + }], + "role": "model" + }, + "finishReason": "STOP" + }], + "modelVersion": model_name, + "usageMetadata": { + "candidatesTokenCount": TEST_OUTPUT_TOKENS, + "promptTokenCount": TEST_INPUT_TOKENS, + "totalTokenCount": TEST_TOTAL_TOKENS + } + }) +} + +pub fn create_mock_google_ai_response(model_name: &str, content: &str) -> Value { + json!({ + "candidates": [{ + "content": { + "parts": [{ + "text": content + }], + "role": "model" + }, + "finishReason": "STOP" + }], + "modelVersion": model_name, + "usageMetadata": { + "candidatesTokenCount": TEST_OUTPUT_TOKENS, + "promptTokenCount": TEST_INPUT_TOKENS, + "totalTokenCount": TEST_TOTAL_TOKENS + } + }) +} + +pub fn create_mock_open_ai_response(model_name: &str, content: &str) -> Value { + json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": content, + "tool_calls": null + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": TEST_INPUT_TOKENS, + "completion_tokens": TEST_OUTPUT_TOKENS, + "total_tokens": TEST_TOTAL_TOKENS + }, + "model": model_name + }) +} + +pub fn create_test_tool() -> Tool { + Tool::new( + "get_weather", + "Gets the current weather for a location", + json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
New York, NY" + } + }, + "required": ["location"] + }), + ) +} + +pub fn get_expected_function_call_arguments() -> Value { + json!({ + "location": "San Francisco, CA" + }) +} diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index feee301cf16f..540f9291f862 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,16 +1,15 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; -use super::utils::{ - get_model, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, - ImageFormat, -}; +use super::utils::{get_model, handle_response}; use crate::message::Message; -use anyhow::{anyhow, Result}; +use crate::providers::openai_utils::{ + create_openai_request_payload, get_openai_usage, openai_response_to_message, +}; +use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; use reqwest::Client; -use reqwest::StatusCode; -use serde_json::{json, Value}; +use serde_json::Value; use std::time::Duration; pub const OLLAMA_HOST: &str = "http://localhost:11434"; @@ -31,30 +30,7 @@ impl OllamaProvider { } fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + get_openai_usage(data) } async fn post(&self, payload: Value) -> Result { @@ -65,62 +41,24 @@ impl OllamaProvider { let response = self.client.post(&url).json(&payload).send().await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { - Err(anyhow!("Server error: {}", status)) - } - _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), - } + handle_response(payload, response).await? 
} } #[async_trait] impl Provider for OllamaProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { - let system_message = json!({ - "role": "system", - "content": system - }); - - let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); - let tools_spec = tools_to_openai_spec(tools)?; - - let mut messages_array = vec![system_message]; - messages_array.extend(messages_spec); - - let mut payload = json!({ - "model": self.config.model.model_name, - "messages": messages_array - }); - - if !tools_spec.is_empty() { - payload - .as_object_mut() - .unwrap() - .insert("tools".to_string(), json!(tools_spec)); - } - if let Some(temp) = self.config.model.temperature { - payload - .as_object_mut() - .unwrap() - .insert("temperature".to_string(), json!(temp)); - } - if let Some(tokens) = self.config.model.max_tokens { - payload - .as_object_mut() - .unwrap() - .insert("max_tokens".to_string(), json!(tokens)); - } + let payload = + create_openai_request_payload(&self.config.model, system, messages, tools, false)?; let response = self.post(payload).await?; @@ -132,28 +70,22 @@ impl Provider for OllamaProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } - - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() - } } #[cfg(test)] mod tests { use super::*; use crate::message::MessageContent; - use serde_json::json; - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; + use crate::providers::mock_server::{ + create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, + get_expected_function_call_arguments, setup_mock_server, + setup_mock_server_with_response_code, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, + TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, + }; + use wiremock::MockServer; async fn _setup_mock_server(response_body: Value) -> (MockServer, OllamaProvider) { - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) - .mount(&mock_server) - .await; - + let mock_server = setup_mock_server("/v1/chat/completions", response_body).await; // Create the OllamaProvider with the mock server's URL as the host let config = OllamaProviderConfig { host: mock_server.uri(), @@ -166,25 +98,10 @@ mod tests { #[tokio::test] async fn test_complete_basic() -> Result<()> { + let model_name = "gpt-4o"; // Mock response for normal completion - let response_body = json!({ - "id": "chatcmpl-123", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I assist you today?", - "tool_calls": null - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 15, - "total_tokens": 27 - } - }); + let response_body = + create_mock_open_ai_response(model_name, "Hello! 
How can I assist you today?"); let (_, provider) = _setup_mock_server(response_body).await; @@ -202,9 +119,11 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(usage.usage.input_tokens, Some(12)); - assert_eq!(usage.usage.output_tokens, Some(15)); - assert_eq!(usage.usage.total_tokens, Some(27)); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); + assert_eq!(usage.cost, None); Ok(()) } @@ -212,82 +131,41 @@ mod tests { #[tokio::test] async fn test_complete_tool_request() -> Result<()> { // Mock response for tool calling - let response_body = json!({ - "id": "chatcmpl-tool", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": "call_h5d3s25w", - "type": "function", - "function": { - "name": "read_file", - "arguments": "{\"filename\":\"test.txt\"}" - } - }] - }, - "finish_reason": "tool_calls" - }], - "usage": { - "prompt_tokens": 63, - "completion_tokens": 70, - "total_tokens": 133 - } - }); + let response_body = create_mock_open_ai_response_with_tools("gpt-4o"); let (_, provider) = _setup_mock_server(response_body).await; // Input messages - let messages = vec![Message::user().with_text("Can you read the test.txt file?")]; - - // Define the tool - let tool = Tool::new( - "read_file", - "Read the content of a file", - json!({ - "type": "object", - "properties": { - "filename": { - "type": "string", - "description": "The name of the file to read" - } - }, - "required": ["filename"] - }), - ); + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; // Call the complete method let (message, usage) = provider - .complete("You are a helpful assistant.", &messages, &[tool]) + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) .await?; // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { let tool_call = tool_request.tool_call.as_ref().unwrap(); - assert_eq!(tool_call.name, "read_file"); - assert_eq!(tool_call.arguments, json!({"filename": "test.txt"})); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); } else { panic!("Expected ToolCall content"); } - assert_eq!(usage.usage.input_tokens, Some(63)); - assert_eq!(usage.usage.output_tokens, Some(70)); - assert_eq!(usage.usage.total_tokens, Some(133)); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); Ok(()) } #[tokio::test] async fn test_server_error() -> Result<()> { - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with(ResponseTemplate::new(500)) - .mount(&mock_server) - .await; + let mock_server = setup_mock_server_with_response_code("/v1/chat/completions", 500).await; let config = OllamaProviderConfig { host: mock_server.uri(), diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 8b6a2748c9b7..f735b2209c64 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,8 +1,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use reqwest::Client; -use 
reqwest::StatusCode; -use serde_json::{json, Value}; +use serde_json::Value; use std::time::Duration; use super::base::ProviderUsage; @@ -11,14 +10,16 @@ use super::configs::OpenAiProviderConfig; use super::configs::{ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; -use super::utils::get_model; -use super::utils::{ - check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, - tools_to_openai_spec, ImageFormat, -}; +use super::utils::{get_model, handle_response}; use crate::message::Message; +use crate::providers::openai_utils::{ + check_openai_context_length_error, create_openai_request_payload, get_openai_usage, + openai_response_to_message, +}; use mcp_core::tool::Tool; +pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o"; + pub struct OpenAiProvider { client: Client, config: OpenAiProviderConfig, @@ -34,30 +35,7 @@ impl OpenAiProvider { } fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + get_openai_usage(data) } async fn post(&self, payload: Value) -> Result { @@ -74,23 +52,16 @@ impl OpenAiProvider { .send() .await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { - // Implement retry logic here if needed - Err(anyhow!("Server error: {}", status)) - } - _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), - } + handle_response(payload, response).await? } } #[async_trait] impl Provider for OpenAiProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, @@ -98,48 +69,8 @@ impl Provider for OpenAiProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { // Not checking for o1 model here since system message is not supported by o1 - let system_message = json!({ - "role": "system", - "content": system - }); - - // Convert messages and tools to OpenAI format - let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); - let tools_spec = if !tools.is_empty() { - tools_to_openai_spec(tools)? 
- } else { - vec![] - }; - - // Build payload - // create messages array with system message first - let mut messages_array = vec![system_message]; - messages_array.extend(messages_spec); - - let mut payload = json!({ - "model": self.config.model.model_name, - "messages": messages_array - }); - - // Add optional parameters - if !tools_spec.is_empty() { - payload - .as_object_mut() - .unwrap() - .insert("tools".to_string(), json!(tools_spec)); - } - if let Some(temp) = self.config.model.temperature { - payload - .as_object_mut() - .unwrap() - .insert("temperature".to_string(), json!(temp)); - } - if let Some(tokens) = self.config.model.max_tokens { - payload - .as_object_mut() - .unwrap() - .insert("max_tokens".to_string(), json!(tokens)); - } + let payload = + create_openai_request_payload(&self.config.model, system, messages, tools, false)?; // Make request let response = self.post(payload).await?; @@ -160,10 +91,6 @@ impl Provider for OpenAiProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } - - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() - } } #[cfg(test)] @@ -171,18 +98,16 @@ mod tests { use super::*; use crate::message::MessageContent; use crate::providers::configs::ModelConfig; + use crate::providers::mock_server::{ + create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, + get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, + TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, + }; use rust_decimal_macros::dec; - use serde_json::json; - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; - - async fn _setup_mock_server(response_body: Value) -> (MockServer, OpenAiProvider) { - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) - .mount(&mock_server) - .await; + use wiremock::MockServer; + + async fn _setup_mock_response(response_body: Value) -> (MockServer, OpenAiProvider) { + let mock_server = setup_mock_server("/v1/chat/completions", response_body).await; // Create the OpenAiProvider with the mock server's URL as the host let config = OpenAiProviderConfig { @@ -197,28 +122,12 @@ mod tests { #[tokio::test] async fn test_complete_basic() -> Result<()> { + let model_name = "gpt-4o"; // Mock response for normal completion - let response_body = json!({ - "id": "chatcmpl-123", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I assist you today?", - "tool_calls": null - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 15, - "total_tokens": 27 - }, - "model": "gpt-4o" - }); - - let (_, provider) = _setup_mock_server(response_body).await; + let response_body = + create_mock_open_ai_response(model_name, "Hello! 
How can I assist you today?"); + + let (_, provider) = _setup_mock_response(response_body).await; // Prepare input messages let messages = vec![Message::user().with_text("Hello?")]; @@ -234,10 +143,10 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(usage.usage.input_tokens, Some(12)); - assert_eq!(usage.usage.output_tokens, Some(15)); - assert_eq!(usage.usage.total_tokens, Some(27)); - assert_eq!(usage.model, "gpt-4o"); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); assert_eq!(usage.cost, Some(dec!(0.00018))); Ok(()) @@ -246,73 +155,36 @@ mod tests { #[tokio::test] async fn test_complete_tool_request() -> Result<()> { // Mock response for tool calling - let response_body = json!({ - "id": "chatcmpl-tool", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": "call_123", - "type": "function", - "function": { - "name": "get_weather", - "arguments": "{\"location\":\"San Francisco, CA\"}" - } - }] - }, - "finish_reason": "tool_calls" - }], - "usage": { - "prompt_tokens": 20, - "completion_tokens": 15, - "total_tokens": 35 - } - }); + let response_body = create_mock_open_ai_response_with_tools("gpt-4o"); - let (_, provider) = _setup_mock_server(response_body).await; + let (_, provider) = _setup_mock_response(response_body).await; // Input messages let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; // Define the tool using builder pattern - let tool = Tool::new( - "get_weather", - "Gets the current weather for a location", - json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
New York, NY" - } - }, - "required": ["location"] - }), - ); // Call the complete method let (message, usage) = provider - .complete("You are a helpful assistant.", &messages, &[tool]) + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) .await?; // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { let tool_call = tool_request.tool_call.as_ref().unwrap(); - assert_eq!(tool_call.name, "get_weather"); - assert_eq!( - tool_call.arguments, - json!({"location": "San Francisco, CA"}) - ); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); } else { panic!("Expected ToolCall content"); } - assert_eq!(usage.usage.input_tokens, Some(20)); - assert_eq!(usage.usage.output_tokens, Some(15)); - assert_eq!(usage.usage.total_tokens, Some(35)); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); Ok(()) } diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs new file mode 100644 index 000000000000..0be8d8917448 --- /dev/null +++ b/crates/goose/src/providers/openai_utils.rs @@ -0,0 +1,604 @@ +use crate::errors::AgentError; +use crate::message::{Message, MessageContent}; +use crate::providers::base::Usage; +use crate::providers::configs::ModelConfig; +use crate::providers::utils::{ + convert_image, is_valid_function_name, sanitize_function_name, ContextLengthExceededError, + ImageFormat, +}; +use anyhow::{anyhow, Error}; +use mcp_core::{Content, Role, Tool, ToolCall}; +use serde_json::{json, Value}; + +/// Convert internal Message format to OpenAI's API message specification +/// some openai compatible endpoints use the anthropic image spec at the content level +/// even though the message structure is otherwise following openai, the enum switches this +pub fn messages_to_openai_spec( + messages: &[Message], + image_format: &ImageFormat, + concat_tool_response_contents: bool, +) -> Vec { + let mut messages_spec = Vec::new(); + for message in messages { + let mut converted = json!({ + "role": message.role + }); + + let mut output = Vec::new(); + + for content in &message.content { + match content { + MessageContent::Text(text) => { + if !text.text.is_empty() { + converted["content"] = json!(text.text); + } + } + MessageContent::ToolRequest(request) => match &request.tool_call { + Ok(tool_call) => { + let sanitized_name = sanitize_function_name(&tool_call.name); + let tool_calls = converted + .as_object_mut() + .unwrap() + .entry("tool_calls") + .or_insert(json!([])); + + tool_calls.as_array_mut().unwrap().push(json!({ + "id": request.id, + "type": "function", + "function": { + "name": sanitized_name, + "arguments": tool_call.arguments.to_string(), + } + })); + } + Err(e) => { + output.push(json!({ + "role": "tool", + "content": format!("Error: {}", e), + "tool_call_id": request.id + })); + } + }, + MessageContent::ToolResponse(response) => { + match &response.tool_result { + Ok(contents) => { + // Send only contents with no audience or with Assistant in the audience + let abridged: Vec<_> = contents + .iter() + .filter(|content| { + content + .audience() + .is_none_or(|audience| audience.contains(&Role::Assistant)) + }) + .map(|content| content.unannotated()) + .collect(); + + // Process all content, replacing images with placeholder text + let mut tool_content = 
Vec::new(); + let mut image_messages = Vec::new(); + + for content in abridged { + match content { + Content::Image(image) => { + // Add placeholder text in the tool response + tool_content.push(Content::text("This tool result included an image that is uploaded in the next message.")); + + // Create a separate image message + image_messages.push(json!({ + "role": "user", + "content": [convert_image(&image, image_format)] + })); + } + _ => { + tool_content.push(content); + } + } + } + let tool_response_content: Value = match concat_tool_response_contents { + true => { + json!(tool_content + .iter() + .map(|content| match content { + Content::Text(text) => text.text.clone(), + _ => String::new(), + }) + .collect::>() + .join(" ")) + } + false => json!(tool_content), + }; + + // First add the tool response with all content + output.push(json!({ + "role": "tool", + "content": tool_response_content, + "tool_call_id": response.id + })); + // Then add any image messages that need to follow + output.extend(image_messages); + } + Err(e) => { + // A tool result error is shown as output so the model can interpret the error message + output.push(json!({ + "role": "tool", + "content": format!("The tool call returned the following error:\n{}", e), + "tool_call_id": response.id + })); + } + } + } + MessageContent::Image(image) => { + // Handle direct image content + converted["content"] = json!([convert_image(image, image_format)]); + } + } + } + + if converted.get("content").is_some() || converted.get("tool_calls").is_some() { + output.insert(0, converted); + } + messages_spec.extend(output); + } + + messages_spec +} + +/// Convert internal Tool format to OpenAI's API tool specification +pub fn tools_to_openai_spec(tools: &[Tool]) -> anyhow::Result> { + let mut tool_names = std::collections::HashSet::new(); + let mut result = Vec::new(); + + for tool in tools { + if !tool_names.insert(&tool.name) { + return Err(anyhow!("Duplicate tool name: {}", tool.name)); + } + + result.push(json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + })); + } + + Ok(result) +} + +/// Convert OpenAI's API response to internal Message format +pub fn openai_response_to_message(response: Value) -> anyhow::Result { + let original = response["choices"][0]["message"].clone(); + let mut content = Vec::new(); + + if let Some(text) = original.get("content") { + if let Some(text_str) = text.as_str() { + content.push(MessageContent::text(text_str)); + } + } + + if let Some(tool_calls) = original.get("tool_calls") { + if let Some(tool_calls_array) = tool_calls.as_array() { + for tool_call in tool_calls_array { + let id = tool_call["id"].as_str().unwrap_or_default().to_string(); + let function_name = tool_call["function"]["name"] + .as_str() + .unwrap_or_default() + .to_string(); + let arguments = tool_call["function"]["arguments"] + .as_str() + .unwrap_or_default() + .to_string(); + + if !is_valid_function_name(&function_name) { + let error = AgentError::ToolNotFound(format!( + "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", + function_name + )); + content.push(MessageContent::tool_request(id, Err(error))); + } else { + match serde_json::from_str::(&arguments) { + Ok(params) => { + content.push(MessageContent::tool_request( + id, + Ok(ToolCall::new(&function_name, params)), + )); + } + Err(e) => { + let error = AgentError::InvalidParameters(format!( + "Could not interpret tool use parameters for id 
{}: {}", + id, e + )); + content.push(MessageContent::tool_request(id, Err(error))); + } + } + } + } + } + } + + Ok(Message { + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content, + }) +} + +pub fn get_openai_usage(data: &Value) -> anyhow::Result { + let usage = data + .get("usage") + .ok_or_else(|| anyhow!("No usage data in response"))?; + + let input_tokens = usage + .get("prompt_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let output_tokens = usage + .get("completion_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let total_tokens = usage + .get("total_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32) + .or_else(|| match (input_tokens, output_tokens) { + (Some(input), Some(output)) => Some(input + output), + _ => None, + }); + + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) +} + +pub fn create_openai_request_payload( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], + concat_tool_response_contents: bool, +) -> anyhow::Result { + let system_message = json!({ + "role": "system", + "content": system + }); + + let messages_spec = messages_to_openai_spec( + messages, + &ImageFormat::OpenAi, + concat_tool_response_contents, + ); + let tools_spec = tools_to_openai_spec(tools)?; + + let mut messages_array = vec![system_message]; + messages_array.extend(messages_spec); + + let mut payload = json!({ + "model": model_config.model_name, + "messages": messages_array + }); + + if !tools_spec.is_empty() { + payload + .as_object_mut() + .unwrap() + .insert("tools".to_string(), json!(tools_spec)); + } + if let Some(temp) = model_config.temperature { + payload + .as_object_mut() + .unwrap() + .insert("temperature".to_string(), json!(temp)); + } + if let Some(tokens) = model_config.max_tokens { + payload + .as_object_mut() + .unwrap() + .insert("max_tokens".to_string(), json!(tokens)); + } + Ok(payload) +} + +pub fn check_openai_context_length_error(error: &Value) -> Option { + let code = error.get("code")?.as_str()?; + if code == "context_length_exceeded" || code == "string_above_max_length" { + let message = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error") + .to_string(); + Some(ContextLengthExceededError(message)) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mcp_core::content::Content; + use serde_json::json; + + const OPENAI_TOOL_USE_RESPONSE: &str = r#"{ + "choices": [{ + "role": "assistant", + "message": { + "tool_calls": [{ + "id": "1", + "function": { + "name": "example_fn", + "arguments": "{\"param\": \"value\"}" + } + }] + } + }], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35 + } + }"#; + + #[test] + fn test_messages_to_openai_spec() -> anyhow::Result<()> { + let message = Message::user().with_text("Hello"); + let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi, false); + + assert_eq!(spec.len(), 1); + assert_eq!(spec[0]["role"], "user"); + assert_eq!(spec[0]["content"], "Hello"); + Ok(()) + } + + #[test] + fn test_tools_to_openai_spec() -> anyhow::Result<()> { + let tool = Tool::new( + "test_tool", + "A test tool", + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Test parameter" + } + }, + "required": ["input"] + }), + ); + + let spec = tools_to_openai_spec(&[tool])?; + + assert_eq!(spec.len(), 1); + assert_eq!(spec[0]["type"], "function"); + assert_eq!(spec[0]["function"]["name"], "test_tool"); + Ok(()) + } + + #[test] 
+ fn test_messages_to_openai_spec_complex() -> anyhow::Result<()> { + let mut messages = vec![ + Message::assistant().with_text("Hello!"), + Message::user().with_text("How are you?"), + Message::assistant().with_tool_request( + "tool1", + Ok(ToolCall::new("example", json!({"param1": "value1"}))), + ), + ]; + + // Get the ID from the tool request to use in the response + let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] { + request.id.clone() + } else { + panic!("should be tool request"); + }; + + messages + .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); + + let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, true); + + assert_eq!(spec.len(), 4); + assert_eq!(spec[0]["role"], "assistant"); + assert_eq!(spec[0]["content"], "Hello!"); + assert_eq!(spec[1]["role"], "user"); + assert_eq!(spec[1]["content"], "How are you?"); + assert_eq!(spec[2]["role"], "assistant"); + assert!(spec[2]["tool_calls"].is_array()); + assert_eq!(spec[3]["role"], "tool"); + assert_eq!(spec[3]["content"], "Result"); + assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]); + + Ok(()) + } + + #[test] + fn test_messages_to_openai_spec_not_concat_tool_response_content() -> anyhow::Result<()> { + let mut messages = vec![Message::assistant().with_tool_request( + "tool1", + Ok(ToolCall::new("example", json!({"param1": "value1"}))), + )]; + + // Get the ID from the tool request to use in the response + let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] { + request.id.clone() + } else { + panic!("should be tool request"); + }; + + messages + .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); + + let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, false); + + assert_eq!(spec.len(), 2); + assert_eq!(spec[0]["role"], "assistant"); + assert!(spec[0]["tool_calls"].is_array()); + assert_eq!(spec[1]["role"], "tool"); + assert_eq!(spec[1]["content"][0]["text"], "Result"); + assert_eq!(spec[1]["tool_call_id"], spec[0]["tool_calls"][0]["id"]); + + Ok(()) + } + + #[test] + fn test_tools_to_openai_spec_duplicate() -> anyhow::Result<()> { + let tool1 = Tool::new( + "test_tool", + "Test tool", + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Test parameter" + } + }, + "required": ["input"] + }), + ); + + let tool2 = Tool::new( + "test_tool", + "Test tool", + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Test parameter" + } + }, + "required": ["input"] + }), + ); + + let result = tools_to_openai_spec(&[tool1, tool2]); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Duplicate tool name")); + + Ok(()) + } + + #[test] + fn test_tools_to_openai_spec_empty() -> anyhow::Result<()> { + let spec = tools_to_openai_spec(&[])?; + assert!(spec.is_empty()); + Ok(()) + } + + #[test] + fn test_openai_response_to_message_text() -> anyhow::Result<()> { + let response = json!({ + "choices": [{ + "role": "assistant", + "message": { + "content": "Hello from John Cena!" 
+                }
+            }],
+            "usage": {
+                "input_tokens": 10,
+                "output_tokens": 25,
+                "total_tokens": 35
+            }
+        });
+
+        let message = openai_response_to_message(response)?;
+        assert_eq!(message.content.len(), 1);
+        if let MessageContent::Text(text) = &message.content[0] {
+            assert_eq!(text.text, "Hello from John Cena!");
+        } else {
+            panic!("Expected Text content");
+        }
+        assert!(matches!(message.role, Role::Assistant));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_openai_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
+        let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
+        let message = openai_response_to_message(response)?;
+
+        assert_eq!(message.content.len(), 1);
+        if let MessageContent::ToolRequest(request) = &message.content[0] {
+            let tool_call = request.tool_call.as_ref().unwrap();
+            assert_eq!(tool_call.name, "example_fn");
+            assert_eq!(tool_call.arguments, json!({"param": "value"}));
+        } else {
+            panic!("Expected ToolRequest content");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_openai_response_to_message_invalid_func_name() -> anyhow::Result<()> {
+        let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
+        response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
+            json!("invalid fn");
+
+        let message = openai_response_to_message(response)?;
+
+        if let MessageContent::ToolRequest(request) = &message.content[0] {
+            match &request.tool_call {
+                Err(AgentError::ToolNotFound(msg)) => {
+                    assert!(msg.starts_with("The provided function name"));
+                }
+                _ => panic!("Expected ToolNotFound error"),
+            }
+        } else {
+            panic!("Expected ToolRequest content");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_openai_response_to_message_json_decode_error() -> anyhow::Result<()> {
+        let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
+        response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
+            json!("invalid json {");
+
+        let message = openai_response_to_message(response)?;
+
+        if let MessageContent::ToolRequest(request) = &message.content[0] {
+            match &request.tool_call {
+                Err(AgentError::InvalidParameters(msg)) => {
+                    assert!(msg.starts_with("Could not interpret tool use parameters"));
+                }
+                _ => panic!("Expected InvalidParameters error"),
+            }
+        } else {
+            panic!("Expected ToolRequest content");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_check_openai_context_length_error() {
+        let error = json!({
+            "code": "context_length_exceeded",
+            "message": "This message is too long"
+        });
+
+        let result = check_openai_context_length_error(&error);
+        assert!(result.is_some());
+        assert_eq!(
+            result.unwrap().to_string(),
+            "Context length exceeded. Message: This message is too long"
+        );
+
+        let error = json!({
+            "code": "other_error",
+            "message": "Some other error"
+        });
+
+        let result = check_openai_context_length_error(&error);
+        assert!(result.is_none());
+    }
+}
diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs
index f3bd5d8ee516..ee52a6059d17 100644
--- a/crates/goose/src/providers/utils.rs
+++ b/crates/goose/src/providers/utils.rs
@@ -1,13 +1,10 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Error, Result};
 use regex::Regex;
+use reqwest::{Response, StatusCode};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Map, Value};
 
-use crate::errors::AgentError;
-use crate::message::{Message, MessageContent};
-use mcp_core::content::{Content, ImageContent};
-use mcp_core::role::Role;
-use mcp_core::tool::{Tool, ToolCall};
+use mcp_core::content::ImageContent;
 
 #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
 pub enum ImageFormat {
@@ -15,124 +12,6 @@ pub enum ImageFormat {
     Anthropic,
 }
 
-/// Convert internal Message format to OpenAI's API message specification
-/// some openai compatible endpoints use the anthropic image spec at the content level
-/// even though the message structure is otherwise following openai, the enum switches this
-pub fn messages_to_openai_spec(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> {
-    let mut messages_spec = Vec::new();
-
-    for message in messages {
-        let mut converted = json!({
-            "role": message.role
-        });
-
-        let mut output = Vec::new();
-
-        for content in &message.content {
-            match content {
-                MessageContent::Text(text) => {
-                    if !text.text.is_empty() {
-                        converted["content"] = json!(text.text);
-                    }
-                }
-                MessageContent::ToolRequest(request) => match &request.tool_call {
-                    Ok(tool_call) => {
-                        let sanitized_name = sanitize_function_name(&tool_call.name);
-                        let tool_calls = converted
-                            .as_object_mut()
-                            .unwrap()
-                            .entry("tool_calls")
-                            .or_insert(json!([]));
-
-                        tool_calls.as_array_mut().unwrap().push(json!({
-                            "id": request.id,
-                            "type": "function",
-                            "function": {
-                                "name": sanitized_name,
-                                "arguments": tool_call.arguments.to_string(),
-                            }
-                        }));
-                    }
-                    Err(e) => {
-                        output.push(json!({
-                            "role": "tool",
-                            "content": format!("Error: {}", e),
-                            "tool_call_id": request.id
-                        }));
-                    }
-                },
-                MessageContent::ToolResponse(response) => {
-                    match &response.tool_result {
-                        Ok(contents) => {
-                            // Send only contents with no audience or with Assistant in the audience
-                            let abridged: Vec<_> = contents
-                                .iter()
-                                .filter(|content| {
-                                    content
-                                        .audience()
-                                        .is_none_or(|audience| audience.contains(&Role::Assistant))
-                                })
-                                .map(|content| content.unannotated())
-                                .collect();
-
-                            // Process all content, replacing images with placeholder text
-                            let mut tool_content = Vec::new();
-                            let mut image_messages = Vec::new();
-
-                            for content in abridged {
-                                match content {
-                                    Content::Image(image) => {
-                                        // Add placeholder text in the tool response
-                                        tool_content.push(Content::text("This tool result included an image that is uploaded in the next message."));
-
-                                        // Create a separate image message
-                                        image_messages.push(json!({
-                                            "role": "user",
-                                            "content": [convert_image(&image, image_format)]
-                                        }));
-                                    }
-                                    _ => {
-                                        tool_content.push(content);
-                                    }
-                                }
-                            }
-
-                            // First add the tool response with all content
-                            output.push(json!({
-                                "role": "tool",
-                                "content": tool_content,
-                                "tool_call_id": response.id
-                            }));
-
-                            // Then add any image messages that need to follow
-                            output.extend(image_messages);
-                        }
-                        Err(e) => {
-                            // A tool result error is shown as output so the model can interpret the error message
-                            output.push(json!({
-                                "role": "tool",
-                                "content": format!("The tool call returned the following error:\n{}", e),
-                                "tool_call_id": response.id
-                            }));
-                        }
-                    }
-                }
-                MessageContent::Image(image) => {
-                    // Handle direct image content
-                    converted["content"] = json!([convert_image(image, image_format)]);
-                }
-            }
-        }
-
-        if converted.get("content").is_some() || converted.get("tool_calls").is_some() {
-            output.insert(0, converted);
-        }
-        messages_spec.extend(output);
-    }
-
-    messages_spec
-}
-
 /// Convert an image content into an image json based on format
 pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value {
     match image_format {
@@ -153,84 +32,18 @@ pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value
     }
 }
 
-/// Convert internal Tool format to OpenAI's API tool specification
-pub fn tools_to_openai_spec(tools: &[Tool]) -> Result<Vec<Value>> {
-    let mut tool_names = std::collections::HashSet::new();
-    let mut result = Vec::new();
-
-    for tool in tools {
-        if !tool_names.insert(&tool.name) {
-            return Err(anyhow!("Duplicate tool name: {}", tool.name));
+pub async fn handle_response(payload: Value, response: Response) -> Result<Result<Value, Error>, Error> {
+    Ok(match response.status() {
+        StatusCode::OK => Ok(response.json().await?),
+        status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => {
+            // Implement retry logic here if needed
+            Err(anyhow!("Server error: {}", status))
         }
-
-        result.push(json!({
-            "type": "function",
-            "function": {
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.input_schema,
-            }
-        }));
-    }
-
-    Ok(result)
-}
-
-/// Convert OpenAI's API response to internal Message format
-pub fn openai_response_to_message(response: Value) -> Result<Message> {
-    let original = response["choices"][0]["message"].clone();
-    let mut content = Vec::new();
-
-    if let Some(text) = original.get("content") {
-        if let Some(text_str) = text.as_str() {
-            content.push(MessageContent::text(text_str));
-        }
-    }
-
-    if let Some(tool_calls) = original.get("tool_calls") {
-        if let Some(tool_calls_array) = tool_calls.as_array() {
-            for tool_call in tool_calls_array {
-                let id = tool_call["id"].as_str().unwrap_or_default().to_string();
-                let function_name = tool_call["function"]["name"]
-                    .as_str()
-                    .unwrap_or_default()
-                    .to_string();
-                let arguments = tool_call["function"]["arguments"]
-                    .as_str()
-                    .unwrap_or_default()
-                    .to_string();
-
-                if !is_valid_function_name(&function_name) {
-                    let error = AgentError::ToolNotFound(format!(
-                        "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
-                        function_name
-                    ));
-                    content.push(MessageContent::tool_request(id, Err(error)));
-                } else {
-                    match serde_json::from_str::<Value>(&arguments) {
-                        Ok(params) => {
-                            content.push(MessageContent::tool_request(
-                                id,
-                                Ok(ToolCall::new(&function_name, params)),
-                            ));
-                        }
-                        Err(e) => {
-                            let error = AgentError::InvalidParameters(format!(
-                                "Could not interpret tool use parameters for id {}: {}",
-                                id, e
-                            ));
-                            content.push(MessageContent::tool_request(id, Err(error)));
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    Ok(Message {
-        role: Role::Assistant,
-        created: chrono::Utc::now().timestamp(),
-        content,
+        _ => Err(anyhow!(
+            "Request failed: {}\nPayload: {}",
+            response.status(),
+            payload
+        )),
     })
 }
 
@@ -246,21 +59,7 @@ pub fn is_valid_function_name(name: &str) -> bool {
 
 #[derive(Debug, thiserror::Error)]
 #[error("Context length exceeded. Message: {0}")]
-pub struct ContextLengthExceededError(String);
-
-pub fn check_openai_context_length_error(error: &Value) -> Option<ContextLengthExceededError> {
-    let code = error.get("code")?.as_str()?;
-    if code == "context_length_exceeded" || code == "string_above_max_length" {
-        let message = error
-            .get("message")
-            .and_then(|m| m.as_str())
-            .unwrap_or("Unknown error")
-            .to_string();
-        Some(ContextLengthExceededError(message))
-    } else {
-        None
-    }
-}
+pub struct ContextLengthExceededError(pub String);
 
 pub fn check_bedrock_context_length_error(error: &Value) -> Option<ContextLengthExceededError> {
     let external_message = error
@@ -319,65 +118,8 @@ pub fn unescape_json_values(value: &Value) -> Value {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use mcp_core::content::Content;
     use serde_json::json;
 
-    const OPENAI_TOOL_USE_RESPONSE: &str = r#"{
-        "choices": [{
-            "role": "assistant",
-            "message": {
-                "tool_calls": [{
-                    "id": "1",
-                    "function": {
-                        "name": "example_fn",
-                        "arguments": "{\"param\": \"value\"}"
-                    }
-                }]
-            }
-        }],
-        "usage": {
-            "input_tokens": 10,
-            "output_tokens": 25,
-            "total_tokens": 35
-        }
-    }"#;
-
-    #[test]
-    fn test_messages_to_openai_spec() -> Result<()> {
-        let message = Message::user().with_text("Hello");
-        let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi);
-
-        assert_eq!(spec.len(), 1);
-        assert_eq!(spec[0]["role"], "user");
-        assert_eq!(spec[0]["content"], "Hello");
-        Ok(())
-    }
-
-    #[test]
-    fn test_tools_to_openai_spec() -> Result<()> {
-        let tool = Tool::new(
-            "test_tool",
-            "A test tool",
-            json!({
-                "type": "object",
-                "properties": {
-                    "input": {
-                        "type": "string",
-                        "description": "Test parameter"
-                    }
-                },
-                "required": ["input"]
-            }),
-        );
-
-        let spec = tools_to_openai_spec(&[tool])?;
-
-        assert_eq!(spec.len(), 1);
-        assert_eq!(spec[0]["type"], "function");
-        assert_eq!(spec[0]["function"]["name"], "test_tool");
-        Ok(())
-    }
-
     #[test]
     fn test_sanitize_function_name() {
         assert_eq!(sanitize_function_name("hello-world"), "hello-world");
@@ -393,207 +135,6 @@ mod tests {
         assert!(!is_valid_function_name("hello@world"));
     }
 
-    #[test]
-    fn test_messages_to_openai_spec_complex() -> Result<()> {
-        let mut messages = vec![
-            Message::assistant().with_text("Hello!"),
-            Message::user().with_text("How are you?"),
-            Message::assistant().with_tool_request(
-                "tool1",
-                Ok(ToolCall::new("example", json!({"param1": "value1"}))),
-            ),
-        ];
-
-        // Get the ID from the tool request to use in the response
-        let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
-            request.id.clone()
-        } else {
-            panic!("should be tool request");
-        };
-
-        messages
-            .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
-
-        let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi);
-
-        assert_eq!(spec.len(), 4);
-        assert_eq!(spec[0]["role"], "assistant");
-        assert_eq!(spec[0]["content"], "Hello!");
-        assert_eq!(spec[1]["role"], "user");
-        assert_eq!(spec[1]["content"], "How are you?");
-        assert_eq!(spec[2]["role"], "assistant");
-        assert!(spec[2]["tool_calls"].is_array());
-        assert_eq!(spec[3]["role"], "tool");
-        assert_eq!(
-            spec[3]["content"],
-            json!([{"text": "Result", "type": "text"}])
-        );
-        assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]);
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_tools_to_openai_spec_duplicate() -> Result<()> {
-        let tool1 = Tool::new(
-            "test_tool",
-            "Test tool",
-            json!({
-                "type": "object",
-                "properties": {
-                    "input": {
-                        "type": "string",
-                        "description": "Test parameter"
-                    }
-                },
-                "required": ["input"]
-            }),
-        );
-
-        let tool2 = Tool::new(
-            "test_tool",
-            "Test tool",
-            json!({
-                "type": "object",
-                "properties": {
-                    "input": {
-                        "type": "string",
-                        "description": "Test parameter"
-                    }
-                },
-                "required": ["input"]
-            }),
-        );
-
-        let result = tools_to_openai_spec(&[tool1, tool2]);
-        assert!(result.is_err());
-        assert!(result
-            .unwrap_err()
-            .to_string()
-            .contains("Duplicate tool name"));
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_tools_to_openai_spec_empty() -> Result<()> {
-        let spec = tools_to_openai_spec(&[])?;
-        assert!(spec.is_empty());
-        Ok(())
-    }
-
-    #[test]
-    fn test_openai_response_to_message_text() -> Result<()> {
-        let response = json!({
-            "choices": [{
-                "role": "assistant",
-                "message": {
-                    "content": "Hello from John Cena!"
-                }
-            }],
-            "usage": {
-                "input_tokens": 10,
-                "output_tokens": 25,
-                "total_tokens": 35
-            }
-        });
-
-        let message = openai_response_to_message(response)?;
-        assert_eq!(message.content.len(), 1);
-        if let MessageContent::Text(text) = &message.content[0] {
-            assert_eq!(text.text, "Hello from John Cena!");
-        } else {
-            panic!("Expected Text content");
-        }
-        assert!(matches!(message.role, Role::Assistant));
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_openai_response_to_message_valid_toolrequest() -> Result<()> {
-        let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
-        let message = openai_response_to_message(response)?;
-
-        assert_eq!(message.content.len(), 1);
-        if let MessageContent::ToolRequest(request) = &message.content[0] {
-            let tool_call = request.tool_call.as_ref().unwrap();
-            assert_eq!(tool_call.name, "example_fn");
-            assert_eq!(tool_call.arguments, json!({"param": "value"}));
-        } else {
-            panic!("Expected ToolRequest content");
-        }
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_openai_response_to_message_invalid_func_name() -> Result<()> {
-        let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
-        response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
-            json!("invalid fn");
-
-        let message = openai_response_to_message(response)?;
-
-        if let MessageContent::ToolRequest(request) = &message.content[0] {
-            match &request.tool_call {
-                Err(AgentError::ToolNotFound(msg)) => {
-                    assert!(msg.starts_with("The provided function name"));
-                }
-                _ => panic!("Expected ToolNotFound error"),
-            }
-        } else {
-            panic!("Expected ToolRequest content");
-        }
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_openai_response_to_message_json_decode_error() -> Result<()> {
-        let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
-        response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
-            json!("invalid json {");
-
-        let message = openai_response_to_message(response)?;
-
-        if let MessageContent::ToolRequest(request) = &message.content[0] {
-            match &request.tool_call {
-                Err(AgentError::InvalidParameters(msg)) => {
-                    assert!(msg.starts_with("Could not interpret tool use parameters"));
-                }
-                _ => panic!("Expected InvalidParameters error"),
-            }
-        } else {
-            panic!("Expected ToolRequest content");
-        }
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_check_openai_context_length_error() {
-        let error = json!({
-            "code": "context_length_exceeded",
-            "message": "This message is too long"
-        });
-
-        let result = check_openai_context_length_error(&error);
-        assert!(result.is_some());
-        assert_eq!(
-            result.unwrap().to_string(),
-            "Context length exceeded. Message: This message is too long"
-        );
-
-        let error = json!({
-            "code": "other_error",
-            "message": "Some other error"
-        });
-
-        let result = check_openai_context_length_error(&error);
-        assert!(result.is_none());
-    }
-
     #[test]
     fn test_check_bedrock_context_length_error() {
         let error = json!({
diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs
index 0a7a5d1127cc..345b0acc68bd 100644
--- a/crates/goose/src/token_counter.rs
+++ b/crates/goose/src/token_counter.rs
@@ -15,6 +15,7 @@ const GPT_4O_TOKENIZER_KEY: &str = "Xenova--gpt-4o";
 const CLAUDE_TOKENIZER_KEY: &str = "Xenova--claude-tokenizer";
 const GOOGLE_TOKENIZER_KEY: &str = "Xenova--gemma-2-tokenizer";
 const QWEN_TOKENIZER_KEY: &str = "Qwen--Qwen2.5-Coder-32B-Instruct";
+const LLAMA_TOKENIZER_KEY: &str = "Xenova--llama3-tokenizer";
 
 impl Default for TokenCounter {
     fn default() -> Self {
@@ -53,6 +54,8 @@ impl TokenCounter {
             GPT_4O_TOKENIZER_KEY,
             CLAUDE_TOKENIZER_KEY,
             GOOGLE_TOKENIZER_KEY,
+            QWEN_TOKENIZER_KEY,
+            LLAMA_TOKENIZER_KEY,
         ] {
             counter.load_tokenizer(tokenizer_key);
         }
@@ -71,6 +74,8 @@ impl TokenCounter {
             QWEN_TOKENIZER_KEY
         } else if model_name.contains("gemini") {
             GOOGLE_TOKENIZER_KEY
+        } else if model_name.contains("llama") {
+            LLAMA_TOKENIZER_KEY
         } else {
             // default
             GPT_4O_TOKENIZER_KEY