Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion crates/goose-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ reqwest = "0.11.27"
rand = "0.8.5"
async-trait = "0.1"
rustyline = "15.0.0"
rust_decimal = "1.36.0"
rust_decimal_macros = "1.36.0"

[dev-dependencies]
tempfile = "3"
temp-env = "0.3.6"
temp-env = { version = "0.3.6", features = ["async_closure"] }

8 changes: 4 additions & 4 deletions crates/goose-cli/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{
agent::Agent as GooseAgent, message::Message, providers::base::Usage, systems::System,
agent::Agent as GooseAgent, message::Message, providers::base::ProviderUsage, systems::System,
};

#[async_trait]
pub trait Agent {
fn add_system(&mut self, system: Box<dyn System>);
async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>>;
fn total_usage(&self) -> Usage;
async fn usage(&self) -> Result<Vec<ProviderUsage>>;
}

#[async_trait]
Expand All @@ -22,7 +22,7 @@ impl Agent for GooseAgent {
self.reply(messages).await
}

fn total_usage(&self) -> Usage {
self.total_usage()
async fn usage(&self) -> Result<Vec<ProviderUsage>> {
self.usage().await
}
}
12 changes: 9 additions & 3 deletions crates/goose-cli/src/agents/mock_agent.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::vec;

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{message::Message, systems::System};
use goose::{message::Message, providers::base::ProviderUsage, systems::System};

use crate::agents::agent::Agent;

Expand All @@ -15,7 +17,11 @@ impl Agent for MockAgent {
Ok(Box::pin(futures::stream::empty()))
}

fn total_usage(&self) -> goose::providers::base::Usage {
goose::providers::base::Usage::default()
async fn usage(&self) -> Result<Vec<ProviderUsage>> {
Ok(vec![ProviderUsage::new(
"mock".to_string(),
Default::default(),
None,
)])
}
}
24 changes: 16 additions & 8 deletions crates/goose-cli/src/log_usage.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use goose::providers::base::Usage;
use goose::providers::base::ProviderUsage;

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct SessionLog {
session_file: String,
usage: goose::providers::base::Usage,
usage: Vec<ProviderUsage>,
}

pub fn log_usage(session_file: String, usage: Usage) {
pub fn log_usage(session_file: String, usage: Vec<ProviderUsage>) {
let log = SessionLog {
session_file,
usage,
Expand Down Expand Up @@ -49,12 +49,14 @@ pub fn log_usage(session_file: String, usage: Usage) {

#[cfg(test)]
mod tests {
use goose::providers::base::Usage;
use goose::providers::base::{ProviderUsage, Usage};
use rust_decimal_macros::dec;

use crate::{
log_usage::{log_usage, SessionLog},
test_helpers::run_with_tmp_dir,
};

#[test]
fn test_session_logging() {
run_with_tmp_dir(|| {
Expand All @@ -63,7 +65,11 @@ mod tests {

log_usage(
"path.txt".to_string(),
Usage::new(Some(10), Some(20), Some(30)),
vec![ProviderUsage::new(
"model".to_string(),
Usage::new(Some(10), Some(20), Some(30)),
Some(dec!(0.5)),
)],
);

// Check if log file exists and contains the expected content
Expand All @@ -75,9 +81,11 @@ mod tests {
serde_json::from_str(log_content.lines().last().unwrap()).unwrap();

assert!(log.session_file.contains("path.txt"));
assert_eq!(log.usage.input_tokens, Some(10));
assert_eq!(log.usage.output_tokens, Some(20));
assert_eq!(log.usage.total_tokens, Some(30));
assert_eq!(log.usage[0].usage.input_tokens, Some(10));
assert_eq!(log.usage[0].usage.output_tokens, Some(20));
assert_eq!(log.usage[0].usage.total_tokens, Some(30));
assert_eq!(log.usage[0].model, "model");
assert_eq!(log.usage[0].cost, Some(dec!(0.5)));
})
}
}
34 changes: 14 additions & 20 deletions crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl<'a> Session<'a> {
self.agent_process_messages().await;
self.prompt.hide_busy();
}
self.close_session();
self.close_session().await;
Ok(())
}

Expand All @@ -162,7 +162,7 @@ impl<'a> Session<'a> {

self.agent_process_messages().await;

self.close_session();
self.close_session().await;
Ok(())
}

Expand Down Expand Up @@ -312,7 +312,7 @@ We've removed the conversation up to the most recent user message
self.agent.add_system(goosehints_system);
}

fn close_session(&mut self) {
async fn close_session(&mut self) {
self.prompt.render(raw_message(
format!(
"Closing session. Recorded to {}\n",
Expand All @@ -321,22 +321,17 @@ We've removed the conversation up to the most recent user message
.as_str(),
));
self.prompt.close();
match self.agent.usage().await {
Ok(usage) => log_usage(self.session_file.to_string_lossy().to_string(), usage),
Err(e) => eprintln!("Failed to collect total provider usage: {}", e),
}
}

pub fn session_file(&self) -> PathBuf {
self.session_file.clone()
}
}

impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
log_usage(
self.session_file.to_string_lossy().to_string(),
self.agent.total_usage(),
);
}
}

fn raw_message(content: &str) -> Box<Message> {
Box::new(Message::assistant().with_text(content))
}
Expand All @@ -348,7 +343,7 @@ mod tests {

use crate::agents::mock_agent::MockAgent;
use crate::prompt::{self, Input};
use crate::test_helpers::run_with_tmp_dir;
use crate::test_helpers::{run_with_tmp_dir, run_with_tmp_dir_async};

use super::*;
use goose::errors::AgentResult;
Expand Down Expand Up @@ -808,19 +803,17 @@ mod tests {
})
}

#[test]
fn test_session_logging() -> Result<()> {
run_with_tmp_dir(|| {
#[tokio::test]
async fn test_session_logging() -> Result<()> {
run_with_tmp_dir_async(|| async {
// Create a test session
let session = create_test_session();
let mut session = create_test_session();
let session_file = session.session_file.clone();
// Create a log directory
let home_dir = dirs::home_dir().unwrap();
let log_dir = home_dir.join(".config").join("goose").join("logs");
std::fs::create_dir_all(&log_dir)?;

// Drop the session to trigger logging
drop(session);
session.close_session().await;

// Check if log file exists and contains the expected content
let log_file = log_dir.join("goose.log");
Expand All @@ -834,6 +827,7 @@ mod tests {

Ok(())
})
.await
}

fn assert_last_prompt_text(session: &Session, expected_text: &str) {
Expand Down
58 changes: 41 additions & 17 deletions crates/goose-cli/src/test_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,52 @@
/// Helper function to set up a temporary home directory for testing, returns path of that temp dir.
/// Also creates a default profiles.json to avoid obscure test failures when there are no profiles.
#[cfg(test)]
pub fn run_with_tmp_dir<F: FnOnce() -> T, T>(func: F) -> T {
use std::ffi::OsStr;
use std::fs;
use tempfile::tempdir;

// Helper function to set up a temporary home directory for testing, returns path of that temp dir.
// Also creates a default profiles.json to avoid obscure test failures when there are no profiles.

let temp_dir = tempdir().unwrap();
// std::env::set_var("HOME", temp_dir.path());
let temp_dir_path = temp_dir.path().to_path_buf();
setup_profile(&temp_dir_path);

temp_env::with_vars(
[
("HOME", Some(temp_dir_path.as_os_str())),
("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))),
],
func,
)
}

#[cfg(test)]
pub async fn run_with_tmp_dir_async<F, Fut, T>(func: F) -> T
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = T>,
{
use std::ffi::OsStr;
use tempfile::tempdir;

let temp_dir = tempdir().unwrap();
let temp_dir_path = temp_dir.path().to_path_buf();
println!(
"Created temporary home directory: {}",
temp_dir_path.display()
);
setup_profile(&temp_dir_path);

temp_env::async_with_vars(
[
("HOME", Some(temp_dir_path.as_os_str())),
("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))),
],
func(),
)
.await
}

#[cfg(test)]
use std::path::PathBuf;
#[cfg(test)]
fn setup_profile(temp_dir_path: &PathBuf) {
use std::fs;

let profile_path = temp_dir_path
.join(".config")
.join("goose")
Expand All @@ -31,12 +63,4 @@ pub fn run_with_tmp_dir<F: FnOnce() -> T, T>(func: F) -> T {
}
}"#;
fs::write(&profile_path, profile).unwrap();

temp_env::with_vars(
[
("HOME", Some(temp_dir_path.as_os_str())),
("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))),
],
func,
)
}
13 changes: 6 additions & 7 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,10 @@ mod tests {
use super::*;
use goose::{
agent::Agent,
providers::{base::Provider, configs::OpenAiProviderConfig},
providers::{
base::{Provider, ProviderUsage, Usage},
configs::OpenAiProviderConfig,
},
};
use mcp_core::tool::Tool;

Expand All @@ -406,16 +409,12 @@ mod tests {
_system_prompt: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, goose::providers::base::Usage), anyhow::Error> {
) -> Result<(Message, ProviderUsage), anyhow::Error> {
Ok((
Message::assistant().with_text("Mock response"),
goose::providers::base::Usage::default(),
ProviderUsage::new("mock".to_string(), Usage::default(), None),
))
}

fn total_usage(&self) -> goose::providers::base::Usage {
goose::providers::base::Usage::default()
}
}

#[test]
Expand Down
2 changes: 2 additions & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ kill_tree = "0.2.4"

keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] }
shellexpand = "3.1.0"
rust_decimal = "1.36.0"
rust_decimal_macros = "1.36.0"

[dev-dependencies]
sysinfo = "0.32.1"
Expand Down
6 changes: 3 additions & 3 deletions crates/goose/examples/databricks_oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ async fn main() -> Result<()> {
}
println!("\nToken Usage:");
println!("------------");
println!("Input tokens: {:?}", usage.input_tokens);
println!("Output tokens: {:?}", usage.output_tokens);
println!("Total tokens: {:?}", usage.total_tokens);
println!("Input tokens: {:?}", usage.usage.input_tokens);
println!("Output tokens: {:?}", usage.usage.output_tokens);
println!("Total tokens: {:?}", usage.usage.total_tokens);

Ok(())
}
6 changes: 3 additions & 3 deletions crates/goose/examples/image_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ async fn main() -> Result<()> {
}
println!("\nToken Usage:");
println!("------------");
println!("Input tokens: {:?}", usage.input_tokens);
println!("Output tokens: {:?}", usage.output_tokens);
println!("Total tokens: {:?}", usage.total_tokens);
println!("Input tokens: {:?}", usage.usage.input_tokens);
println!("Output tokens: {:?}", usage.usage.output_tokens);
println!("Total tokens: {:?}", usage.usage.total_tokens);
}

Ok(())
Expand Down
Loading
Loading