diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index b50cc65a04c3..4821c57c8586 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -34,8 +34,10 @@ reqwest = "0.11.27" rand = "0.8.5" async-trait = "0.1" rustyline = "15.0.0" +rust_decimal = "1.36.0" +rust_decimal_macros = "1.36.0" [dev-dependencies] tempfile = "3" -temp-env = "0.3.6" +temp-env = { version = "0.3.6", features = ["async_closure"] } diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs index 2ca850215904..eb0833489159 100644 --- a/crates/goose-cli/src/agents/agent.rs +++ b/crates/goose-cli/src/agents/agent.rs @@ -2,14 +2,14 @@ use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; use goose::{ - agent::Agent as GooseAgent, message::Message, providers::base::Usage, systems::System, + agent::Agent as GooseAgent, message::Message, providers::base::ProviderUsage, systems::System, }; #[async_trait] pub trait Agent { fn add_system(&mut self, system: Box); async fn reply(&self, messages: &[Message]) -> Result>>; - fn total_usage(&self) -> Usage; + async fn usage(&self) -> Result>; } #[async_trait] @@ -22,7 +22,7 @@ impl Agent for GooseAgent { self.reply(messages).await } - fn total_usage(&self) -> Usage { - self.total_usage() + async fn usage(&self) -> Result> { + self.usage().await } } diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index 2466c0aa8ca6..542431ac3c0e 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -1,7 +1,9 @@ +use std::vec; + use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; -use goose::{message::Message, systems::System}; +use goose::{message::Message, providers::base::ProviderUsage, systems::System}; use crate::agents::agent::Agent; @@ -15,7 +17,11 @@ impl Agent for MockAgent { Ok(Box::pin(futures::stream::empty())) } - fn total_usage(&self) 
-> goose::providers::base::Usage { - goose::providers::base::Usage::default() + async fn usage(&self) -> Result> { + Ok(vec![ProviderUsage::new( + "mock".to_string(), + Default::default(), + None, + )]) } } diff --git a/crates/goose-cli/src/log_usage.rs b/crates/goose-cli/src/log_usage.rs index eea5605e836d..8c21d2e980de 100644 --- a/crates/goose-cli/src/log_usage.rs +++ b/crates/goose-cli/src/log_usage.rs @@ -1,12 +1,12 @@ -use goose::providers::base::Usage; +use goose::providers::base::ProviderUsage; #[derive(Debug, serde::Serialize, serde::Deserialize)] struct SessionLog { session_file: String, - usage: goose::providers::base::Usage, + usage: Vec, } -pub fn log_usage(session_file: String, usage: Usage) { +pub fn log_usage(session_file: String, usage: Vec) { let log = SessionLog { session_file, usage, @@ -49,12 +49,14 @@ pub fn log_usage(session_file: String, usage: Usage) { #[cfg(test)] mod tests { - use goose::providers::base::Usage; + use goose::providers::base::{ProviderUsage, Usage}; + use rust_decimal_macros::dec; use crate::{ log_usage::{log_usage, SessionLog}, test_helpers::run_with_tmp_dir, }; + #[test] fn test_session_logging() { run_with_tmp_dir(|| { @@ -63,7 +65,11 @@ mod tests { log_usage( "path.txt".to_string(), - Usage::new(Some(10), Some(20), Some(30)), + vec![ProviderUsage::new( + "model".to_string(), + Usage::new(Some(10), Some(20), Some(30)), + Some(dec!(0.5)), + )], ); // Check if log file exists and contains the expected content @@ -75,9 +81,11 @@ mod tests { serde_json::from_str(log_content.lines().last().unwrap()).unwrap(); assert!(log.session_file.contains("path.txt")); - assert_eq!(log.usage.input_tokens, Some(10)); - assert_eq!(log.usage.output_tokens, Some(20)); - assert_eq!(log.usage.total_tokens, Some(30)); + assert_eq!(log.usage[0].usage.input_tokens, Some(10)); + assert_eq!(log.usage[0].usage.output_tokens, Some(20)); + assert_eq!(log.usage[0].usage.total_tokens, Some(30)); + assert_eq!(log.usage[0].model, "model"); + 
assert_eq!(log.usage[0].cost, Some(dec!(0.5))); }) } } diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index b03470d2bffb..cc1ce4b96899 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -146,7 +146,7 @@ impl<'a> Session<'a> { self.agent_process_messages().await; self.prompt.hide_busy(); } - self.close_session(); + self.close_session().await; Ok(()) } @@ -162,7 +162,7 @@ impl<'a> Session<'a> { self.agent_process_messages().await; - self.close_session(); + self.close_session().await; Ok(()) } @@ -312,7 +312,7 @@ We've removed the conversation up to the most recent user message self.agent.add_system(goosehints_system); } - fn close_session(&mut self) { + async fn close_session(&mut self) { self.prompt.render(raw_message( format!( "Closing session. Recorded to {}\n", @@ -321,6 +321,10 @@ We've removed the conversation up to the most recent user message .as_str(), )); self.prompt.close(); + match self.agent.usage().await { + Ok(usage) => log_usage(self.session_file.to_string_lossy().to_string(), usage), + Err(e) => eprintln!("Failed to collect total provider usage: {}", e), + } } pub fn session_file(&self) -> PathBuf { @@ -328,15 +332,6 @@ We've removed the conversation up to the most recent user message } } -impl<'a> Drop for Session<'a> { - fn drop(&mut self) { - log_usage( - self.session_file.to_string_lossy().to_string(), - self.agent.total_usage(), - ); - } -} - fn raw_message(content: &str) -> Box { Box::new(Message::assistant().with_text(content)) } @@ -348,7 +343,7 @@ mod tests { use crate::agents::mock_agent::MockAgent; use crate::prompt::{self, Input}; - use crate::test_helpers::run_with_tmp_dir; + use crate::test_helpers::{run_with_tmp_dir, run_with_tmp_dir_async}; use super::*; use goose::errors::AgentResult; @@ -808,19 +803,17 @@ mod tests { }) } - #[test] - fn test_session_logging() -> Result<()> { - run_with_tmp_dir(|| { + #[tokio::test] + async fn test_session_logging() -> Result<()> { + 
run_with_tmp_dir_async(|| async { // Create a test session - let session = create_test_session(); + let mut session = create_test_session(); let session_file = session.session_file.clone(); // Create a log directory let home_dir = dirs::home_dir().unwrap(); let log_dir = home_dir.join(".config").join("goose").join("logs"); - std::fs::create_dir_all(&log_dir)?; - // Drop the session to trigger logging - drop(session); + session.close_session().await; // Check if log file exists and contains the expected content let log_file = log_dir.join("goose.log"); @@ -834,6 +827,7 @@ mod tests { Ok(()) }) + .await } fn assert_last_prompt_text(session: &Session, expected_text: &str) { diff --git a/crates/goose-cli/src/test_helpers.rs b/crates/goose-cli/src/test_helpers.rs index e3cee5cbcf81..8e611da0b15b 100644 --- a/crates/goose-cli/src/test_helpers.rs +++ b/crates/goose-cli/src/test_helpers.rs @@ -1,20 +1,52 @@ +/// Helper function to set up a temporary home directory for testing, returns path of that temp dir. +/// Also creates a default profiles.json to avoid obscure test failures when there are no profiles. #[cfg(test)] pub fn run_with_tmp_dir T, T>(func: F) -> T { use std::ffi::OsStr; - use std::fs; use tempfile::tempdir; - // Helper function to set up a temporary home directory for testing, returns path of that temp dir. - // Also creates a default profiles.json to avoid obscure test failures when there are no profiles. 
- let temp_dir = tempdir().unwrap(); - // std::env::set_var("HOME", temp_dir.path()); + let temp_dir_path = temp_dir.path().to_path_buf(); + setup_profile(&temp_dir_path); + temp_env::with_vars( + [ + ("HOME", Some(temp_dir_path.as_os_str())), + ("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))), + ], + func, + ) +} + +#[cfg(test)] +pub async fn run_with_tmp_dir_async(func: F) -> T +where + F: FnOnce() -> Fut, + Fut: std::future::Future, +{ + use std::ffi::OsStr; + use tempfile::tempdir; + + let temp_dir = tempdir().unwrap(); let temp_dir_path = temp_dir.path().to_path_buf(); - println!( - "Created temporary home directory: {}", - temp_dir_path.display() - ); + setup_profile(&temp_dir_path); + + temp_env::async_with_vars( + [ + ("HOME", Some(temp_dir_path.as_os_str())), + ("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))), + ], + func(), + ) + .await +} + +#[cfg(test)] +use std::path::PathBuf; +#[cfg(test)] +fn setup_profile(temp_dir_path: &PathBuf) { + use std::fs; + let profile_path = temp_dir_path .join(".config") .join("goose") @@ -31,12 +63,4 @@ pub fn run_with_tmp_dir T, T>(func: F) -> T { } }"#; fs::write(&profile_path, profile).unwrap(); - - temp_env::with_vars( - [ - ("HOME", Some(temp_dir_path.as_os_str())), - ("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))), - ], - func, - ) } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 97be4e6e6da5..c5e07e2878d0 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -391,7 +391,10 @@ mod tests { use super::*; use goose::{ agent::Agent, - providers::{base::Provider, configs::OpenAiProviderConfig}, + providers::{ + base::{Provider, ProviderUsage, Usage}, + configs::OpenAiProviderConfig, + }, }; use mcp_core::tool::Tool; @@ -406,16 +409,12 @@ mod tests { _system_prompt: &str, _messages: &[Message], _tools: &[Tool], - ) -> Result<(Message, goose::providers::base::Usage), anyhow::Error> { + ) -> Result<(Message, 
ProviderUsage), anyhow::Error> { Ok(( Message::assistant().with_text("Mock response"), - goose::providers::base::Usage::default(), + ProviderUsage::new("mock".to_string(), Usage::default(), None), )) } - - fn total_usage(&self) -> goose::providers::base::Usage { - goose::providers::base::Usage::default() - } } #[test] diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 0b39cc5e1d3b..55bf58c3f10a 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -46,6 +46,8 @@ kill_tree = "0.2.4" keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } shellexpand = "3.1.0" +rust_decimal = "1.36.0" +rust_decimal_macros = "1.36.0" [dev-dependencies] sysinfo = "0.32.1" diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index 36dff0290d17..ce431eb1049e 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -41,9 +41,9 @@ async fn main() -> Result<()> { } println!("\nToken Usage:"); println!("------------"); - println!("Input tokens: {:?}", usage.input_tokens); - println!("Output tokens: {:?}", usage.output_tokens); - println!("Total tokens: {:?}", usage.total_tokens); + println!("Input tokens: {:?}", usage.usage.input_tokens); + println!("Output tokens: {:?}", usage.usage.output_tokens); + println!("Total tokens: {:?}", usage.usage.total_tokens); Ok(()) } diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index 3756720283bd..25baacb744e8 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -89,9 +89,9 @@ async fn main() -> Result<()> { } println!("\nToken Usage:"); println!("------------"); - println!("Input tokens: {:?}", usage.input_tokens); - println!("Output tokens: {:?}", usage.output_tokens); - println!("Total tokens: {:?}", usage.total_tokens); + println!("Input tokens: {:?}", usage.usage.input_tokens); + println!("Output 
tokens: {:?}", usage.usage.output_tokens); + println!("Total tokens: {:?}", usage.usage.total_tokens); } Ok(()) diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs index 39fa446713ae..bf82b9823cee 100644 --- a/crates/goose/src/agent.rs +++ b/crates/goose/src/agent.rs @@ -1,13 +1,15 @@ use anyhow::Result; use async_stream; use futures::stream::BoxStream; +use rust_decimal_macros::dec; use serde_json::json; use std::collections::HashMap; +use tokio::sync::Mutex; use crate::errors::{AgentError, AgentResult}; use crate::message::{Message, ToolRequest}; use crate::prompt_template::load_prompt_file; -use crate::providers::base::Provider; +use crate::providers::base::{Provider, ProviderUsage}; use crate::systems::System; use crate::token_counter::TokenCounter; use mcp_core::{Content, Resource, Tool, ToolCall}; @@ -56,6 +58,7 @@ impl SystemStatus { pub struct Agent { systems: Vec>, provider: Box, + provider_usage: Mutex>, } #[allow(dead_code)] @@ -65,6 +68,7 @@ impl Agent { Self { systems: Vec::new(), provider, + provider_usage: Mutex::new(Vec::new()), } } @@ -339,11 +343,12 @@ impl Agent { loop { // Get completion from provider - let (response, _) = self.provider.complete( + let (response, usage) = self.provider.complete( &system_prompt, &messages, &tools, ).await?; + self.provider_usage.lock().await.push(usage); // The assistant's response is added in rewrite_messages_on_tool_response // Yield the assistant's response @@ -396,8 +401,32 @@ impl Agent { })) } - pub fn total_usage(&self) -> crate::providers::base::Usage { - self.provider.total_usage() + pub async fn usage(&self) -> Result> { + let provider_usage = self.provider_usage.lock().await.clone(); + + let mut usage_map: HashMap = HashMap::new(); + provider_usage.iter().for_each(|usage| { + usage_map + .entry(usage.model.clone()) + .and_modify(|e| { + e.usage.input_tokens = Some( + e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), + ); + e.usage.output_tokens = Some( + 
e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), + ); + e.usage.total_tokens = Some( + e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), + ); + if e.cost.is_none() || usage.cost.is_none() { + e.cost = None; // Pricing is not available for all models + } else { + e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); + } + }) + .or_insert_with(|| usage.clone()); + }); + Ok(usage_map.into_values().collect()) } } @@ -507,6 +536,32 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_usage_rollup() -> Result<()> { + let response = Message::assistant().with_text("Hello!"); + let provider = MockProvider::new(vec![response.clone()]); + let agent = Agent::new(Box::new(provider)); + + let initial_message = Message::user().with_text("Hi"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + while stream.try_next().await?.is_some() {} + + // Second message + let mut stream = agent.reply(&initial_messages).await?; + while stream.try_next().await?.is_some() {} + + let usage = agent.usage().await?; + assert_eq!(usage.len(), 1); // 2 messages rolled up to one usage per model + assert_eq!(usage[0].usage.input_tokens, Some(2)); + assert_eq!(usage[0].usage.output_tokens, Some(2)); + assert_eq!(usage[0].usage.total_tokens, Some(4)); + assert_eq!(usage[0].model, "mock"); + assert_eq!(usage[0].cost, Some(dec!(2))); + Ok(()) + } + #[tokio::test] async fn test_tool_call() -> Result<()> { let mut agent = Agent::new(Box::new(MockProvider::new(vec![ diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 18834d14b7ae..e60eb851f1d1 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -3,6 +3,7 @@ pub mod base; pub mod configs; pub mod databricks; pub mod factory; +pub mod model_pricing; pub mod oauth; pub mod ollama; pub mod openai; diff --git a/crates/goose/src/providers/anthropic.rs 
b/crates/goose/src/providers/anthropic.rs index c3c5d2cb4e23..9648d6592a53 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -6,8 +6,12 @@ use serde_json::{json, Value}; use std::collections::HashSet; use std::time::Duration; -use super::base::{Provider, ProviderUsageCollector, Usage}; +use super::base::ProviderUsage; +use super::base::{Provider, Usage}; use super::configs::AnthropicProviderConfig; +use super::model_pricing::cost; +use super::model_pricing::model_pricing_for; +use super::utils::get_model; use crate::message::{Message, MessageContent}; use mcp_core::content::Content; use mcp_core::role::Role; @@ -16,7 +20,6 @@ use mcp_core::tool::{Tool, ToolCall}; pub struct AnthropicProvider { client: Client, config: AnthropicProviderConfig, - usage_collector: ProviderUsageCollector, } impl AnthropicProvider { @@ -25,11 +28,7 @@ impl AnthropicProvider { .timeout(Duration::from_secs(600)) // 10 minutes timeout .build()?; - Ok(Self { - client, - config, - usage_collector: ProviderUsageCollector::new(), - }) + Ok(Self { client, config }) } fn get_usage(data: &Value) -> Result { @@ -216,7 +215,7 @@ impl Provider for AnthropicProvider { system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, Usage)> { + ) -> Result<(Message, ProviderUsage)> { let anthropic_messages = Self::messages_to_anthropic_spec(messages); let tool_specs = Self::tools_to_anthropic_spec(tools); @@ -261,19 +260,17 @@ impl Provider for AnthropicProvider { // Parse response let message = Self::parse_anthropic_response(response.clone())?; let usage = Self::get_usage(&response)?; - self.usage_collector.add_usage(usage.clone()); + let model = get_model(&response); + let cost = cost(&usage, &model_pricing_for(&model)); - Ok((message, usage)) - } - - fn total_usage(&self) -> Usage { - self.usage_collector.get_usage() + Ok((message, ProviderUsage::new(model, usage, cost))) } } #[cfg(test)] mod tests { use super::*; + use 
rust_decimal_macros::dec; use serde_json::json; use wiremock::matchers::{header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -309,7 +306,7 @@ mod tests { "type": "text", "text": "Hello! How can I assist you today?" }], - "model": "claude-3-sonnet-20240229", + "model": "claude-3-5-sonnet-latest", "stop_reason": "end_turn", "stop_sequence": null, "usage": { @@ -332,15 +329,11 @@ mod tests { panic!("Expected Text content"); } - assert_eq!(usage.input_tokens, Some(12)); - assert_eq!(usage.output_tokens, Some(15)); - assert_eq!(usage.total_tokens, Some(27)); - - // Check total usage - let total = provider.total_usage(); - assert_eq!(total.input_tokens, Some(12)); - assert_eq!(total.output_tokens, Some(15)); - assert_eq!(total.total_tokens, Some(27)); + assert_eq!(usage.usage.input_tokens, Some(12)); + assert_eq!(usage.usage.output_tokens, Some(15)); + assert_eq!(usage.usage.total_tokens, Some(27)); + assert_eq!(usage.model, "claude-3-5-sonnet-latest"); + assert_eq!(usage.cost, Some(dec!(0.000261))); Ok(()) } @@ -397,9 +390,9 @@ mod tests { panic!("Expected ToolRequest content"); } - assert_eq!(usage.input_tokens, Some(15)); - assert_eq!(usage.output_tokens, Some(20)); - assert_eq!(usage.total_tokens, Some(35)); + assert_eq!(usage.usage.input_tokens, Some(15)); + assert_eq!(usage.usage.output_tokens, Some(20)); + assert_eq!(usage.usage.total_tokens, Some(35)); Ok(()) } diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 004b86ee5dce..84dbb78d35b5 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -1,10 +1,30 @@ use anyhow::Result; +use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -use std::sync::Mutex; use crate::message::Message; use mcp_core::tool::Tool; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderUsage { + pub model: String, + pub usage: Usage, + pub cost: Option, +} + +impl ProviderUsage { + pub fn new(model: String, usage: 
Usage, cost: Option) -> Self { + Self { model, usage, cost } + } +} + +#[derive(Debug, Clone)] +pub struct Pricing { + /// Prices are per million tokens. + pub input_token_price: Decimal, + pub output_token_price: Decimal, +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct Usage { pub input_tokens: Option, @@ -39,56 +59,13 @@ pub trait Provider: Send + Sync { /// * `tools` - Optional list of tools the model can use /// /// # Returns - /// A tuple containing the model's response message and usage statistics + /// A tuple containing the model's response message and provider usage statistics async fn complete( &self, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, Usage)>; - - /// Providers should implement this method to return their total usage statistics from provider.complete (or others) - fn total_usage(&self) -> Usage; -} - -/// A simple struct to reuse for collecting usage statistics for provider implementations. -pub struct ProviderUsageCollector { - usage: Mutex, -} - -impl Default for ProviderUsageCollector { - fn default() -> Self { - Self::new() - } -} - -impl ProviderUsageCollector { - pub fn new() -> Self { - Self { - usage: Mutex::new(Usage::default()), - } - } - - pub fn add_usage(&self, usage: Usage) { - if let Ok(mut current) = self.usage.lock() { - if let Some(input_tokens) = usage.input_tokens { - current.input_tokens = Some(current.input_tokens.unwrap_or(0) + input_tokens); - } - if let Some(output_tokens) = usage.output_tokens { - current.output_tokens = Some(current.output_tokens.unwrap_or(0) + output_tokens); - } - if let Some(total_tokens) = usage.total_tokens { - current.total_tokens = Some(current.total_tokens.unwrap_or(0) + total_tokens); - } - } - } - - pub fn get_usage(&self) -> Usage { - self.usage - .lock() - .map(|guard| guard.clone()) - .unwrap_or_default() - } + ) -> Result<(Message, ProviderUsage)>; } #[cfg(test)] @@ -122,23 +99,4 @@ mod tests { Ok(()) } - - #[test] - fn 
test_usage_collector() { - let collector = ProviderUsageCollector::new(); - - // Add first usage - collector.add_usage(Usage::new(Some(10), Some(20), Some(30))); - let usage1 = collector.get_usage(); - assert_eq!(usage1.input_tokens, Some(10)); - assert_eq!(usage1.output_tokens, Some(20)); - assert_eq!(usage1.total_tokens, Some(30)); - - // Add second usage - collector.add_usage(Usage::new(Some(5), Some(10), Some(15))); - let usage2 = collector.get_usage(); - assert_eq!(usage2.input_tokens, Some(15)); - assert_eq!(usage2.output_tokens, Some(30)); - assert_eq!(usage2.total_tokens, Some(45)); - } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index f89a7360d363..0f385e6f02ef 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -4,12 +4,13 @@ use reqwest::{Client, StatusCode}; use serde_json::{json, Value}; use std::time::Duration; -use super::base::{Provider, ProviderUsageCollector, Usage}; +use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{DatabricksAuth, DatabricksProviderConfig}; +use super::model_pricing::{cost, model_pricing_for}; use super::oauth; use super::utils::{ - check_bedrock_context_length_error, check_openai_context_length_error, messages_to_openai_spec, - openai_response_to_message, tools_to_openai_spec, + check_bedrock_context_length_error, check_openai_context_length_error, get_model, + messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, }; use crate::message::Message; use mcp_core::tool::Tool; @@ -17,7 +18,6 @@ use mcp_core::tool::Tool; pub struct DatabricksProvider { client: Client, config: DatabricksProviderConfig, - usage_collector: ProviderUsageCollector, } impl DatabricksProvider { @@ -26,11 +26,7 @@ impl DatabricksProvider { .timeout(Duration::from_secs(600)) // 10 minutes timeout .build()?; - Ok(Self { - client, - config, - usage_collector: ProviderUsageCollector::new(), - }) + Ok(Self { client, 
config }) } async fn ensure_auth_header(&self) -> Result { @@ -114,7 +110,7 @@ impl Provider for DatabricksProvider { system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, Usage)> { + ) -> Result<(Message, ProviderUsage)> { // Prepare messages and tools let messages_spec = messages_to_openai_spec(messages, &self.config.image_format); let tools_spec = if !tools.is_empty() { @@ -167,13 +163,10 @@ impl Provider for DatabricksProvider { // Parse response let message = openai_response_to_message(response.clone())?; let usage = Self::get_usage(&response)?; - self.usage_collector.add_usage(usage.clone()); + let model = get_model(&response); + let cost = cost(&usage, &model_pricing_for(&model)); - Ok((message, usage)) - } - - fn total_usage(&self) -> Usage { - self.usage_collector.get_usage() + Ok((message, ProviderUsage::new(model, usage, cost))) } } @@ -248,13 +241,9 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(reply_usage.total_tokens, Some(35)); - - // Check total usage - let total = provider.total_usage(); - assert_eq!(total.input_tokens, Some(10)); - assert_eq!(total.output_tokens, Some(25)); - assert_eq!(total.total_tokens, Some(35)); + assert_eq!(reply_usage.usage.input_tokens, Some(10)); + assert_eq!(reply_usage.usage.output_tokens, Some(25)); + assert_eq!(reply_usage.usage.total_tokens, Some(35)); Ok(()) } diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index 71ef5fd0e481..aedc3d67648c 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -1,5 +1,6 @@ use anyhow::Result; use async_trait::async_trait; +use rust_decimal_macros::dec; use std::sync::Arc; use std::sync::Mutex; @@ -7,6 +8,8 @@ use crate::message::Message; use crate::providers::base::{Provider, Usage}; use mcp_core::tool::Tool; +use super::base::ProviderUsage; + /// A mock provider that returns pre-configured responses for testing pub struct MockProvider { responses: Arc>>, @@ 
-28,17 +31,20 @@ impl Provider for MockProvider { _system_prompt: &str, _messages: &[Message], _tools: &[Tool], - ) -> Result<(Message, Usage)> { + ) -> Result<(Message, ProviderUsage)> { let mut responses = self.responses.lock().unwrap(); + let usage = Usage::new(Some(1), Some(1), Some(2)); if responses.is_empty() { // Return empty response if no more pre-configured responses - Ok((Message::assistant().with_text(""), Usage::default())) + Ok(( + Message::assistant().with_text(""), + ProviderUsage::new("mock".to_string(), usage, Some(dec!(1))), + )) } else { - Ok((responses.remove(0), Usage::default())) + Ok(( + responses.remove(0), + ProviderUsage::new("mock".to_string(), usage, Some(dec!(1))), + )) } } - - fn total_usage(&self) -> Usage { - Usage::default() - } } diff --git a/crates/goose/src/providers/model_pricing.rs b/crates/goose/src/providers/model_pricing.rs new file mode 100644 index 000000000000..0c3a15f7debe --- /dev/null +++ b/crates/goose/src/providers/model_pricing.rs @@ -0,0 +1,125 @@ +use std::collections::HashMap; + +use rust_decimal::Decimal; +use rust_decimal_macros::dec; + +use super::base::{Pricing, Usage}; + +lazy_static::lazy_static! 
{ + static ref MODEL_PRICING: HashMap = { + let mut m = HashMap::new(); + // Anthropic + m.insert("claude-3-5-sonnet-latest".to_string(), Pricing { + input_token_price: dec!(3), + output_token_price: dec!(15), + }); + m.insert("claude-3-5-sonnet-20241022".to_string(), Pricing { + input_token_price: dec!(3), + output_token_price: dec!(15), + }); + m.insert("anthropic.claude-3-5-sonnet-20241022-v2:0".to_string(), Pricing { + input_token_price: dec!(3), + output_token_price: dec!(15), + }); + m.insert("claude-3-5-sonnet-20241022-v2:0".to_string(), Pricing { + input_token_price: dec!(3), + output_token_price: dec!(15), + }); + m.insert("claude-3-5-sonnet-v2@20241022".to_string(), Pricing { + input_token_price: dec!(3), + output_token_price: dec!(15), + }); + m.insert("claude-3-5-haiku-latest".to_string(), Pricing { + input_token_price: dec!(0.8), + output_token_price: dec!(4), + }); + m.insert("claude-3-5-haiku-20241022".to_string(), Pricing { + input_token_price: dec!(0.8), + output_token_price: dec!(4), + }); + m.insert("anthropic.claude-3-5-haiku-20241022-v1:0".to_string(), Pricing { + input_token_price: dec!(0.8), + output_token_price: dec!(4), + }); + m.insert("claude-3-5-haiku@20241022".to_string(), Pricing { + input_token_price: dec!(0.8), + output_token_price: dec!(4), + }); + m.insert("claude-3-opus-latest".to_string(), Pricing { + input_token_price: dec!(15.00), + output_token_price: dec!(75.00), + }); + m.insert("claude-3-opus-20240229".to_string(), Pricing { + input_token_price: dec!(15.00), + output_token_price: dec!(75.00), + }); + m.insert("anthropic.claude-3-opus-20240229-v1:0".to_string(), Pricing { + input_token_price: dec!(15.00), + output_token_price: dec!(75.00), + }); + m.insert("claude-3-opus@20240229".to_string(), Pricing { + input_token_price: dec!(15.00), + output_token_price: dec!(75.00), + }); + // OpenAI + m.insert("gpt-4o".to_string(), Pricing { + input_token_price: dec!(2.50), + output_token_price: dec!(10.00), + }); + 
m.insert("gpt-4o-2024-11-20".to_string(), Pricing { + input_token_price: dec!(2.50), + output_token_price: dec!(10.00), + }); + m.insert("gpt-4o-2024-08-06".to_string(), Pricing { + input_token_price: dec!(2.50), + output_token_price: dec!(10.00), + }); + m.insert("gpt-4o-2024-05-13".to_string(), Pricing { + input_token_price: dec!(5.00), + output_token_price: dec!(15.00), + }); + m.insert("gpt-4o-mini".to_string(), Pricing { + input_token_price: dec!(0.150), + output_token_price: dec!(0.600), + }); + m.insert("gpt-4o-mini-2024-07-18".to_string(), Pricing { + input_token_price: dec!(0.150), + output_token_price: dec!(0.600), + }); + m.insert("o1-preview".to_string(), Pricing { + input_token_price: dec!(15.00), + output_token_price: dec!(60.00), + }); + m.insert("o1-preview-2024-09-12".to_string(), Pricing { + input_token_price: dec!(15.00), + output_token_price: dec!(60.00), + }); + m.insert("o1-mini".to_string(), Pricing { + input_token_price: dec!(3.00), + output_token_price: dec!(12.00), + }); + m.insert("o1-mini-2024-09-12".to_string(), Pricing { + input_token_price: dec!(3.00), + output_token_price: dec!(12.00), + }); + m + }; +} + +pub fn model_pricing_for(model: &str) -> Option { + MODEL_PRICING.get(model).cloned() +} + +pub fn cost(usage: &Usage, model_pricing: &Option) -> Option { + if let Some(model_pricing) = model_pricing { + let input_price = Decimal::from(usage.input_tokens.unwrap_or(0)) + * model_pricing.input_token_price + / Decimal::from(1_000_000); + let output_price = Decimal::from(usage.output_tokens.unwrap_or(0)) + * model_pricing.output_token_price + / Decimal::from(1_000_000); + Some(input_price + output_price) + } else { + None + } +} diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 129e35bd6df1..36af2f6dd947 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,7 +1,8 @@ -use super::base::{Provider, ProviderUsageCollector, Usage}; +use 
super::base::{Provider, ProviderUsage, Usage}; use super::configs::OllamaProviderConfig; use super::utils::{ - messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, + get_model, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, + ImageFormat, }; use crate::message::Message; use anyhow::{anyhow, Result}; @@ -18,7 +19,6 @@ pub const OLLAMA_MODEL: &str = "qwen2.5"; pub struct OllamaProvider { client: Client, config: OllamaProviderConfig, - usage_collector: ProviderUsageCollector, } impl OllamaProvider { @@ -27,11 +27,7 @@ impl OllamaProvider { .timeout(Duration::from_secs(600)) // 10 minutes timeout .build()?; - Ok(Self { - client, - config, - usage_collector: ProviderUsageCollector::new(), - }) + Ok(Self { client, config }) } fn get_usage(data: &Value) -> Result { @@ -90,7 +86,7 @@ impl Provider for OllamaProvider { system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, Usage)> { + ) -> Result<(Message, ProviderUsage)> { let system_message = json!({ "role": "system", "content": system @@ -131,13 +127,10 @@ impl Provider for OllamaProvider { // Parse response let message = openai_response_to_message(response.clone())?; let usage = Self::get_usage(&response)?; - self.usage_collector.add_usage(usage.clone()); + let model = get_model(&response); + let cost = None; - Ok((message, usage)) - } - - fn total_usage(&self) -> Usage { - self.usage_collector.get_usage() + Ok((message, ProviderUsage::new(model, usage, cost))) } } @@ -207,15 +200,9 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(usage.input_tokens, Some(12)); - assert_eq!(usage.output_tokens, Some(15)); - assert_eq!(usage.total_tokens, Some(27)); - - // Check total usage - let total = provider.total_usage(); - assert_eq!(total.input_tokens, Some(12)); - assert_eq!(total.output_tokens, Some(15)); - assert_eq!(total.total_tokens, Some(27)); + assert_eq!(usage.usage.input_tokens, Some(12)); + 
assert_eq!(usage.usage.output_tokens, Some(15)); + assert_eq!(usage.usage.total_tokens, Some(27)); Ok(()) } @@ -284,15 +271,9 @@ mod tests { panic!("Expected ToolCall content"); } - assert_eq!(usage.input_tokens, Some(63)); - assert_eq!(usage.output_tokens, Some(70)); - assert_eq!(usage.total_tokens, Some(133)); - - // Check total usage - let total = provider.total_usage(); - assert_eq!(total.input_tokens, Some(63)); - assert_eq!(total.output_tokens, Some(70)); - assert_eq!(total.total_tokens, Some(133)); + assert_eq!(usage.usage.input_tokens, Some(63)); + assert_eq!(usage.usage.output_tokens, Some(70)); + assert_eq!(usage.usage.total_tokens, Some(133)); Ok(()) } diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index d62ca7eb048e..7816f12cd6dd 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -5,8 +5,12 @@ use reqwest::StatusCode; use serde_json::{json, Value}; use std::time::Duration; -use super::base::{Provider, ProviderUsageCollector, Usage}; +use super::base::ProviderUsage; +use super::base::{Provider, Usage}; use super::configs::OpenAiProviderConfig; +use super::model_pricing::cost; +use super::model_pricing::model_pricing_for; +use super::utils::get_model; use super::utils::{ check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, @@ -17,7 +21,6 @@ use mcp_core::tool::Tool; pub struct OpenAiProvider { client: Client, config: OpenAiProviderConfig, - usage_collector: ProviderUsageCollector, } impl OpenAiProvider { @@ -26,11 +29,7 @@ impl OpenAiProvider { .timeout(Duration::from_secs(600)) // 10 minutes timeout .build()?; - Ok(Self { - client, - config, - usage_collector: ProviderUsageCollector::new(), - }) + Ok(Self { client, config }) } fn get_usage(data: &Value) -> Result<Usage> { @@ -96,7 +95,7 @@ impl Provider for OpenAiProvider { system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, Usage)> 
{ + ) -> Result<(Message, ProviderUsage)> { // Not checking for o1 model here since system message is not supported by o1 let system_message = json!({ "role": "system", @@ -155,13 +154,10 @@ impl Provider for OpenAiProvider { // Parse response let message = openai_response_to_message(response.clone())?; let usage = Self::get_usage(&response)?; - self.usage_collector.add_usage(usage.clone()); + let model = get_model(&response); + let cost = cost(&usage, &model_pricing_for(&model)); - Ok((message, usage)) - } - - fn total_usage(&self) -> Usage { - self.usage_collector.get_usage() + Ok((message, ProviderUsage::new(model, usage, cost))) } } @@ -169,6 +165,7 @@ impl Provider for OpenAiProvider { mod tests { use super::*; use crate::message::MessageContent; + use rust_decimal_macros::dec; use serde_json::json; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -213,7 +210,8 @@ mod tests { "prompt_tokens": 12, "completion_tokens": 15, "total_tokens": 27 - } + }, + "model": "gpt-4o" }); let (_, provider) = _setup_mock_server(response_body).await; @@ -232,15 +230,11 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(usage.input_tokens, Some(12)); - assert_eq!(usage.output_tokens, Some(15)); - assert_eq!(usage.total_tokens, Some(27)); - - // Check total usage - let total = provider.total_usage(); - assert_eq!(total.input_tokens, Some(12)); - assert_eq!(total.output_tokens, Some(15)); - assert_eq!(total.total_tokens, Some(27)); + assert_eq!(usage.usage.input_tokens, Some(12)); + assert_eq!(usage.usage.output_tokens, Some(15)); + assert_eq!(usage.usage.total_tokens, Some(27)); + assert_eq!(usage.model, "gpt-4o"); + assert_eq!(usage.cost, Some(dec!(0.00018))); Ok(()) } @@ -312,9 +306,9 @@ mod tests { panic!("Expected ToolCall content"); } - assert_eq!(usage.input_tokens, Some(20)); - assert_eq!(usage.output_tokens, Some(15)); - assert_eq!(usage.total_tokens, Some(35)); + assert_eq!(usage.usage.input_tokens, 
Some(20)); + assert_eq!(usage.usage.output_tokens, Some(15)); + assert_eq!(usage.usage.total_tokens, Some(35)); Ok(()) } diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index b337603faa9c..f300dd041405 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -274,6 +274,19 @@ pub fn check_bedrock_context_length_error(error: &Value) -> Option<Value> { + +pub fn get_model(data: &Value) -> String { + if let Some(model) = data.get("model") { + if let Some(model_str) = model.as_str() { + model_str.to_string() + } else { + "Unknown".to_string() + } + } else { + "Unknown".to_string() + } +} + #[cfg(test)] mod tests { use super::*;