214 changes: 162 additions & 52 deletions crates/goose/tests/pricing_integration_test.rs
@@ -1,11 +1,12 @@
use goose::providers::pricing::{get_model_pricing, initialize_pricing_cache, refresh_pricing};
use std::time::Instant;
use tempfile::TempDir;

#[tokio::test]
async fn test_pricing_cache_performance() {
// Use a unique cache directory for this test to avoid conflicts
let test_cache_dir = format!("/tmp/goose_test_cache_perf_{}", std::process::id());
std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir);
let temp_dir = TempDir::new().unwrap();
std::env::set_var("GOOSE_CACHE_DIR", temp_dir.path());

// Initialize the cache
let start = Instant::now();
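The switch from a hand-rolled /tmp path to tempfile keeps each test's cache isolated and cleaned up automatically. A minimal sketch of the pattern, assuming the tempfile crate and that the pricing module resolves its cache directory from GOOSE_CACHE_DIR (the helper name is illustrative, not part of this change):

use tempfile::TempDir;

// Runs `f` with GOOSE_CACHE_DIR pointing at a fresh temporary directory.
// The directory is deleted when `temp_dir` is dropped at the end of the call.
fn with_isolated_cache_dir<T>(f: impl FnOnce() -> T) -> T {
    let temp_dir = TempDir::new().expect("failed to create temp dir");
    std::env::set_var("GOOSE_CACHE_DIR", temp_dir.path());
    let result = f();
    std::env::remove_var("GOOSE_CACHE_DIR");
    result
}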
@@ -24,7 +25,7 @@ async fn test_pricing_cache_performance() {
("anthropic", "claude-sonnet-4"),
];

// First fetch (should hit cache)
// First fetch (may still be uncached or warming the cache)
let start = Instant::now();
for (provider, model) in &models {
let pricing = get_model_pricing(provider, model).await;
@@ -42,77 +43,143 @@ async fn test_pricing_cache_performance() {
first_fetch_duration
);

// Second fetch (definitely from cache)
let start = Instant::now();
for (provider, model) in &models {
let pricing = get_model_pricing(provider, model).await;
assert!(
pricing.is_some(),
"Expected pricing for {}/{}",
provider,
model
);
// Run many iterations to test cache performance
const ITERATIONS: u32 = 100;
let mut total_duration = std::time::Duration::ZERO;
let mut min_duration = std::time::Duration::MAX;
let mut max_duration = std::time::Duration::ZERO;

for i in 0..ITERATIONS {
let start = Instant::now();
for (provider, model) in &models {
let pricing = get_model_pricing(provider, model).await;
assert!(
pricing.is_some(),
"Expected pricing for {}/{}",
provider,
model
);
}
let iteration_duration = start.elapsed();
total_duration += iteration_duration;
min_duration = min_duration.min(iteration_duration);
max_duration = max_duration.max(iteration_duration);

// Print progress every 20 iterations
if (i + 1) % 20 == 0 {
println!("Completed {} iterations", i + 1);
}
}
let second_fetch_duration = start.elapsed();
println!(
"Second fetch of {} models took: {:?}",
models.len(),
second_fetch_duration

let avg_duration = total_duration / ITERATIONS;

println!("\nCache performance over {} iterations:", ITERATIONS);
println!(" Average duration: {:?}", avg_duration);
println!(" Min duration: {:?}", min_duration);
println!(" Max duration: {:?}", max_duration);
println!(" First fetch duration: {:?}", first_fetch_duration);

// The average cached fetch should not be slower than the first fetch,
// which already absorbs the remaining warm-up cost and system-load variance
assert!(
avg_duration <= first_fetch_duration,
"Average cache fetch ({:?}) should not be slower than initial fetch ({:?})",
avg_duration,
first_fetch_duration
);

// Cache fetch should be significantly faster
// Note: Both fetches are already very fast (microseconds), so we just ensure
// the second fetch is not slower than the first (allowing for some variance)
// Also check that eventually (min duration) the cache is faster
// This ensures that after warming up, the cache provides benefit
assert!(
second_fetch_duration <= first_fetch_duration * 2,
"Cache fetch should not be significantly slower than initial fetch. First: {:?}, Second: {:?}",
first_fetch_duration,
second_fetch_duration
min_duration <= first_fetch_duration,
"Best cache performance ({:?}) should be at least as fast as initial fetch ({:?})",
min_duration,
first_fetch_duration
);

// Clean up
std::env::remove_var("GOOSE_CACHE_DIR");
let _ = std::fs::remove_dir_all(&test_cache_dir);
}
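The iteration loop above is essentially a tiny benchmark harness. A hedged sketch of the same min/avg/max measurement factored into a reusable helper (the time_iterations name and signature are illustrative, not part of this change):

use std::future::Future;
use std::time::{Duration, Instant};

// Calls `f` `iterations` times and returns (min, average, max) wall-clock durations.
async fn time_iterations<F, Fut>(iterations: u32, mut f: F) -> (Duration, Duration, Duration)
where
    F: FnMut() -> Fut,
    Fut: Future<Output = ()>,
{
    let mut total = Duration::ZERO;
    let mut min = Duration::MAX;
    let mut max = Duration::ZERO;
    for _ in 0..iterations {
        let start = Instant::now();
        f().await;
        let elapsed = start.elapsed();
        total += elapsed;
        min = min.min(elapsed);
        max = max.max(elapsed);
    }
    // Guard against division by zero if called with iterations == 0.
    (min, total / iterations.max(1), max)
}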

#[tokio::test]
async fn test_pricing_refresh() {
// Use a unique cache directory for this test to avoid conflicts
let test_cache_dir = format!("/tmp/goose_test_cache_refresh_{}", std::process::id());
std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir);
let temp_dir = TempDir::new().unwrap();
std::env::set_var("GOOSE_CACHE_DIR", temp_dir.path());

const MAX_RETRIES: u32 = 5;
let mut attempt = 0;
let mut last_error = None;

while attempt < MAX_RETRIES {
attempt += 1;
println!("Attempt {} of {}", attempt, MAX_RETRIES);

// Try to run the test
match run_pricing_refresh_test().await {
Ok(_) => {
println!("Test passed on attempt {}", attempt);
// Clear any error recorded by earlier attempts so the post-loop
// check does not report failure after a successful retry
last_error = None;
break;
}
Err(e) => {
println!("Attempt {} failed: {}", attempt, e);
last_error = Some(e);

if attempt < MAX_RETRIES {
println!("Retrying in 1 second...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
}

// If all attempts failed, panic with the last error
if attempt == MAX_RETRIES && last_error.is_some() {
panic!(
"Test failed after {} attempts. Last error: {}",
MAX_RETRIES,
last_error.unwrap()
);
}

// Clean up
std::env::remove_var("GOOSE_CACHE_DIR");
}
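test_pricing_refresh and test_concurrent_access now wrap their bodies in the same retry-on-flaky-network loop. A hedged sketch of a shared helper that captures the pattern (the retry_async name is illustrative, not part of this change):

use std::future::Future;
use std::time::Duration;

// Runs `op` up to `max_retries` times, sleeping `delay` between failed attempts.
async fn retry_async<F, Fut>(max_retries: u32, delay: Duration, mut op: F) -> Result<(), String>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<(), String>>,
{
    let mut last_error = String::new();
    for attempt in 1..=max_retries {
        match op().await {
            Ok(()) => return Ok(()),
            Err(e) => {
                println!("Attempt {} failed: {}", attempt, e);
                last_error = e;
                if attempt < max_retries {
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
    Err(format!("failed after {} attempts: {}", max_retries, last_error))
}

// Example use from a #[tokio::test] body:
//     retry_async(5, Duration::from_secs(1), run_pricing_refresh_test)
//         .await
//         .unwrap();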

async fn run_pricing_refresh_test() -> Result<(), String> {
// Initialize first
initialize_pricing_cache()
.await
.expect("Failed to initialize pricing cache");
.map_err(|e| format!("Failed to initialize pricing cache: {}", e))?;

// Get initial pricing (using a model that actually exists)
let initial_pricing = get_model_pricing("anthropic", "claude-3.5-sonnet").await;
assert!(initial_pricing.is_some(), "Expected initial pricing");
if initial_pricing.is_none() {
return Err("Expected initial pricing but got None".to_string());
}

// Force refresh
let start = Instant::now();
refresh_pricing().await.expect("Failed to refresh pricing");
refresh_pricing()
.await
.map_err(|e| format!("Failed to refresh pricing: {}", e))?;
let refresh_duration = start.elapsed();
println!("Pricing refresh took: {:?}", refresh_duration);

// Get pricing after refresh
let refreshed_pricing = get_model_pricing("anthropic", "claude-3.5-sonnet").await;
assert!(
refreshed_pricing.is_some(),
"Expected pricing after refresh"
);
if refreshed_pricing.is_none() {
return Err("Expected pricing after refresh but got None".to_string());
}

// Clean up
std::env::remove_var("GOOSE_CACHE_DIR");
let _ = std::fs::remove_dir_all(&test_cache_dir);
Ok(())
}
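The helper's body replaces panicking expect calls with map_err(...)? so a transient failure returns an Err the retry loop can catch rather than aborting the test process. The general shape, sketched with a stand-in fallible call:

// Stand-in for a network-backed call such as initialize_pricing_cache().
async fn fallible_step() -> Result<(), std::io::Error> {
    Ok(())
}

async fn step_as_result() -> Result<(), String> {
    // `?` propagates the formatted error to the caller instead of panicking,
    // which is what lets the outer retry loop decide whether to try again.
    fallible_step()
        .await
        .map_err(|e| format!("fallible_step failed: {}", e))?;
    Ok(())
}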

#[tokio::test]
async fn test_model_not_in_openrouter() {
// Use a unique cache directory for this test to avoid conflicts
let test_cache_dir = format!("/tmp/goose_test_cache_model_{}", std::process::id());
std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir);
let temp_dir = TempDir::new().unwrap();
std::env::set_var("GOOSE_CACHE_DIR", temp_dir.path());

initialize_pricing_cache()
.await
@@ -127,20 +194,60 @@ async fn test_model_not_in_openrouter() {

// Clean up
std::env::remove_var("GOOSE_CACHE_DIR");
let _ = std::fs::remove_dir_all(&test_cache_dir);
// TempDir automatically cleans up when dropped
}

#[tokio::test]
async fn test_concurrent_access() {
use tokio::task;

// Use a unique cache directory for this test to avoid conflicts
let test_cache_dir = format!("/tmp/goose_test_cache_concurrent_{}", std::process::id());
std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir);
let temp_dir = TempDir::new().unwrap();
std::env::set_var("GOOSE_CACHE_DIR", temp_dir.path());

const MAX_RETRIES: u32 = 5;
let mut attempt = 0;
let mut last_error = None;

while attempt < MAX_RETRIES {
attempt += 1;
println!("Attempt {} of {}", attempt, MAX_RETRIES);

// Try to run the test
match run_concurrent_access_test().await {
Ok(_) => {
println!("Test passed on attempt {}", attempt);
// Clear any error recorded by earlier attempts so the post-loop
// check does not report failure after a successful retry
last_error = None;
break;
}
Err(e) => {
println!("Attempt {} failed: {}", attempt, e);
last_error = Some(e);

if attempt < MAX_RETRIES {
println!("Retrying in 1 second...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
}

// If all attempts failed, panic with the last error
if attempt == MAX_RETRIES && last_error.is_some() {
panic!(
"Test failed after {} attempts. Last error: {}",
MAX_RETRIES,
last_error.unwrap()
);
}

// Clean up
std::env::remove_var("GOOSE_CACHE_DIR");
}

async fn run_concurrent_access_test() -> Result<(), String> {
use tokio::task;

initialize_pricing_cache()
.await
.expect("Failed to initialize pricing cache");
.map_err(|e| format!("Failed to initialize pricing cache: {}", e))?;

// Spawn multiple tasks to access pricing concurrently
let mut handles = vec![];
@@ -156,13 +263,16 @@ async fn test_concurrent_access() {
}

// Wait for all tasks
for handle in handles {
let (task_id, has_pricing, duration) = handle.await.unwrap();
assert!(has_pricing, "Task {} should have gotten pricing", task_id);
for (idx, handle) in handles.into_iter().enumerate() {
let (task_id, has_pricing, duration) = handle
.await
.map_err(|e| format!("Task {} panicked: {}", idx, e))?;

if !has_pricing {
return Err(format!("Task {} should have gotten pricing", task_id));
}
println!("Task {} took: {:?}", task_id, duration);
}

// Clean up
std::env::remove_var("GOOSE_CACHE_DIR");
let _ = std::fs::remove_dir_all(&test_cache_dir);
Ok(())
}
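A minimal, self-contained sketch of the concurrency pattern this test exercises: several Tokio tasks reading the pricing cache at once, joined and checked afterwards. It assumes get_model_pricing is safe to call concurrently; the task count and model here are illustrative:

use std::time::Instant;
use goose::providers::pricing::get_model_pricing;

async fn fetch_pricing_from_many_tasks(task_count: usize) -> Result<(), String> {
    let mut handles = Vec::new();
    for task_id in 0..task_count {
        // Each spawned task does an independent cached lookup and reports its timing.
        handles.push(tokio::task::spawn(async move {
            let start = Instant::now();
            let pricing = get_model_pricing("anthropic", "claude-3.5-sonnet").await;
            (task_id, pricing.is_some(), start.elapsed())
        }));
    }
    for (idx, handle) in handles.into_iter().enumerate() {
        // Awaiting a JoinHandle fails only if the task panicked or was cancelled.
        let (task_id, has_pricing, duration) = handle
            .await
            .map_err(|e| format!("Task {} panicked: {}", idx, e))?;
        if !has_pricing {
            return Err(format!("Task {} got no pricing", task_id));
        }
        println!("Task {} took: {:?}", task_id, duration);
    }
    Ok(())
}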