diff --git a/.gitignore b/.gitignore index caab83d726c7..35892eaff54d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ tokenizer_files/ .idea *.log tmp/ +**/*.env # Generated by Cargo # will have compiled files and executables diff --git a/Cargo.lock b/Cargo.lock index 9541e0b7dbf9..0b9b3efb536c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3597,11 +3597,14 @@ dependencies = [ "criterion", "ctor", "dotenv", + "futures", "include_dir", "indoc 1.0.9", "lazy_static", + "md5", "minijinja", "once_cell", + "parking_lot", "regex", "reqwest 0.12.12", "serde", @@ -5465,6 +5468,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "measure_time" version = "0.8.3" diff --git a/bindings/kotlin/CONNECTION_POOLING.md b/bindings/kotlin/CONNECTION_POOLING.md new file mode 100644 index 000000000000..babc549c6634 --- /dev/null +++ b/bindings/kotlin/CONNECTION_POOLING.md @@ -0,0 +1,163 @@ +# Connection Pooling for Goose LLM in Kotlin + +This document explains how to use connection pooling with goose-llm in your Kotlin application to improve performance when making many parallel requests. + +## Overview + +Connection pooling reuses provider connections instead of creating a new one for each request. This provides significant performance improvements when: + +1. Making many parallel requests +2. Processing numerous sequential requests +3. Handling high-throughput agent loops + +## Important Note + +The current implementation has been updated for type safety for compatibility with Rust's UniFFI bindings. If you encounter any issues, please check that: + +1. You're using the latest version of the Kotlin bindings +2. The values passed to `configureProviderPool` are within appropriate ranges for `u32` types + +## Basic Usage + +### 1. Initialize and Configure the Pool + +Initialize the provider pool at application startup: + +```kotlin +import uniffi.goose_llm.* + +// Initialize with default settings (10 max connections, 5 min idle timeout) +initProviderPool() + +// Or with custom configuration +configureProviderPool( + maxSize = 20, // Maximum number of connections in the pool + maxIdleSec = 300, // Maximum idle time (seconds) before cleanup + maxLifetimeSec = 3600, // Maximum lifetime (seconds) for a connection + maxUses = 100 // Maximum number of uses before recycling +) +``` + +### 2. Create Completion Requests with Pool Option + +```kotlin +val request = createCompletionRequest( + providerName = "databricks", + providerConfig = providerConfig, + modelConfig = modelConfig, + systemPreamble = "You are a helpful assistant.", + messages = messages, + extensions = extensions, + usePool = true // Enable connection pooling (default is true) +) + +// Process the completion +val response = completion(request) +``` + +### 3. Monitor Pool Statistics + +```kotlin +// Get pool statistics as a string +val stats = getPoolStats() +println(stats) + +// Example output: +// Pool: openai:d41d8cd98f00b204e9800998ecf8427e:gpt-4o +// Created: 5 +// Borrowed: 15 +// Returned: 15 +// Errors: 0 +// Max Pool Size: 10 +// Current Pool Size: 5 +// Waiting: 0 +``` + +## Advanced Usage + +### Parallel Completion Service + +The `ParallelCompletionService` class provides a high-level wrapper for processing multiple messages in parallel with connection pooling: + +```kotlin +val service = ParallelCompletionService( + providerName = "databricks", + providerConfig = providerConfig, + modelConfig = modelConfig, + maxConcurrency = 5, + usePool = true +) + +// Process multiple message lists in parallel +val responses = service.processInParallel(messageListsToProcess) +``` + +### Agent Service + +For production agent services, use the `AgentService` class: + +```kotlin +val agentService = AgentService( + providerName = "databricks", + providerConfig = providerConfig, + modelName = "goose-gpt-4-1", + maxConcurrentAgents = 10, + useConnectionPooling = true +) + +// Process agent requests in parallel +val responses = agentService.processMessagesInParallel(requests) + +// Get service metrics +val metrics = agentService.getMetrics() +``` + +## Performance Benchmarking + +Use the `PoolPerformanceTester` to benchmark your specific workload: + +```kotlin +val tester = PoolPerformanceTester( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig +) + +val result = tester.runBenchmark( + iterations = 10, // Number of requests to make + parallelism = 5, // Max parallel requests + messageSupplier = { /* create test messages */ } +) + +println(result) +``` + +## Connection Pool Configuration Recommendations + +| Scenario | Max Size | Idle Timeout | Lifetime | Max Uses | +|----------|----------|--------------|----------|----------| +| Low volume | 5 | 300s (5min) | 3600s (1hr) | 50 | +| Medium volume | 10-20 | 600s (10min) | 7200s (2hr) | 100 | +| High volume | 30-50 | 900s (15min) | 10800s (3hr) | 200 | +| Batch processing | 50-100 | 300s (5min) | 3600s (1hr) | 50 | + +## Troubleshooting + +If you encounter issues: + +1. Check pool statistics to diagnose connection usage +2. Ensure the pool size is appropriate for your concurrency level +3. Try increasing timeout values if connections are being recycled too often +4. Fall back to non-pooled connections by setting `usePool = false` to compare behavior + +## Notes and Limitations + +- The pool uses separate connection groups per provider config and model +- Connection errors are handled by creating new connections +- Provider maintenance is automatic with periodic cleanup of idle connections +- Metrics are available via `getPoolStats()` for monitoring + +For more examples, see: +- `AgentService.kt` - Production service with connection pooling +- `ProviderPool.kt` - Kotlin wrapper for the provider pool +- `PoolingDemo.kt` - Interactive demo and benchmark \ No newline at end of file diff --git a/bindings/kotlin/CONNECTION_POOLING_POC.md b/bindings/kotlin/CONNECTION_POOLING_POC.md new file mode 100644 index 000000000000..a02ad7089757 --- /dev/null +++ b/bindings/kotlin/CONNECTION_POOLING_POC.md @@ -0,0 +1,244 @@ +# Connection Pooling Proof of Concept for Goose LLM + +This document outlines a proof-of-concept implementation for adding connection pooling to the goose-llm FFI bindings in Kotlin. + +## Architecture Overview + +The connection pooling system would consist of: + +1. **Provider Pool in Rust**: + - Thread-safe pool of provider instances + - Configurable pool size, idle timeout, and connection limits + - Auto-cleanup of idle connections + +2. **Kotlin Bindings**: + - Direct access to pool configuration functions + - Connection pool statistics + - Integration with completion requests + +## Implementation + +### Rust Side (crates/goose-llm) + +1. Create `ProviderPool` in `providers/pool.rs`: +```rust +pub struct ProviderPool { + providers: Arc>>>, + config: PoolConfig, +} + +pub struct PoolConfig { + pub max_size: u32, + pub max_idle_seconds: u64, + pub max_lifetime_seconds: u64, + pub max_uses: u32, +} + +impl ProviderPool { + // Get or create a provider + pub async fn get_provider(&self, name: &str, config: serde_json::Value, model: ModelConfig) -> Result, ProviderError>; + + // Return a provider to the pool + pub fn return_provider(&self, provider: Arc); + + // Clean up idle providers + pub fn cleanup_idle(&self); + + // Get pool statistics + pub fn get_stats(&self) -> PoolStats; +} +``` + +2. Add global pool manager in `completion.rs`: +```rust +// Initialize the pool +#[uniffi::export] +pub fn init_provider_pool() { + // Initialize the global provider pool + let _ = PROVIDER_POOL.get_or_init(|| Arc::new(ProviderPool::new(PoolConfig::default()))); +} + +// Configure the pool +#[uniffi::export] +pub fn configure_provider_pool( + max_size: u32, + max_idle_seconds: u64, + max_lifetime_seconds: u64, + max_uses: u32, +) { + let config = PoolConfig { + max_size, + max_idle_seconds, + max_lifetime_seconds, + max_uses, + }; + + if let Some(pool) = PROVIDER_POOL.get() { + pool.update_config(config); + } else { + let _ = PROVIDER_POOL.get_or_init(|| Arc::new(ProviderPool::new(config))); + } +} + +// Get statistics about the provider pool +#[uniffi::export] +pub fn get_pool_stats() -> String { + if let Some(pool) = PROVIDER_POOL.get() { + format!("{:?}", pool.get_stats()) + } else { + "Provider pool not initialized".into() + } +} +``` + +3. Update `CompletionRequest` to include a pool option: +```rust +pub struct CompletionRequest { + // existing fields... + pub use_pool: Option, +} +``` + +4. Update `completion` function to use pooled providers: +```rust +#[uniffi::export(async_runtime = "tokio")] +pub async fn completion(req: CompletionRequest) -> Result { + // Check if we should use pooling + let use_pool = req.use_pool.unwrap_or(true); + + let provider = if use_pool && PROVIDER_POOL.get().is_some() { + // Get provider from pool + PROVIDER_POOL + .get() + .unwrap() + .get_provider(&req.provider_name, req.provider_config.clone(), req.model_config.clone()) + .await + .map_err(|e| CompletionError::Provider(e))? + } else { + // Create provider directly + create_provider(&req.provider_name, req.provider_config.clone(), req.model_config.clone()) + .map_err(|_| CompletionError::UnknownProvider(req.provider_name.clone()))? + }; + + // Rest of completion function... +} +``` + +### Kotlin Side (bindings/kotlin) + +1. Add provider pool functions in `uniffi_goose_llm.kt`: +```kotlin +/** + * Initialize the provider pool with default configuration + */ +fun initProviderPool() { + // FFI call to init_provider_pool +} + +/** + * Configure the provider pool with custom settings + */ +fun configureProviderPool( + maxSize: Int, + maxIdleSeconds: Long, + maxLifetimeSeconds: Long, + maxUses: Int +) { + // FFI call to configure_provider_pool +} + +/** + * Get statistics about the provider pool + */ +fun getPoolStats(): String { + // FFI call to get_pool_stats +} + +/** + * Create a completion request with optional pool setting + */ +fun createCompletionRequest( + // existing parameters + usePool: Boolean? = null +): CompletionRequest { + // Create request with usePool parameter +} +``` + +2. Create helper class for managing pool: +```kotlin +class ProviderPool { + companion object { + /** + * Initialize the provider pool with default settings + */ + fun initialize() { + initProviderPool() + } + + /** + * Configure the provider pool with custom settings + */ + fun configure(maxSize: Int = 10, maxIdleSeconds: Long = 300, + maxLifetimeSeconds: Long = 3600, maxUses: Int = 100) { + configureProviderPool(maxSize, maxIdleSeconds, maxLifetimeSeconds, maxUses) + } + + /** + * Get statistics about the current connection pool + */ + fun stats(): String { + return getPoolStats() + } + } +} +``` + +## Usage Example + +```kotlin +// Initialize and configure the pool +ProviderPool.initialize() +ProviderPool.configure(maxSize = 20, maxIdleSeconds = 300) + +// Create a completion request with pooling +val request = createCompletionRequest( + providerName = "openai", + providerConfig = providerConfig, + modelConfig = modelConfig, + systemPreamble = "You are a helpful assistant.", + messages = messages, + extensions = emptyList(), + usePool = true // Enable connection pooling +) + +// Process the completion +val response = completion(request) + +// Get pool statistics +println(ProviderPool.stats()) +``` + +## Performance Benefits + +The connection pooling implementation would provide the following benefits: + +1. **Reduced Latency**: By reusing existing connections, we eliminate the overhead of creating new providers for each request. + +2. **Higher Throughput**: More efficient connection handling means more requests can be processed in parallel. + +3. **Less Resource Usage**: Fewer connections means less memory usage and fewer file descriptors. + +4. **Improved Stability**: Better handling of connection limits helps avoid rate limiting issues. + +## Next Steps + +To implement this proof of concept: + +1. Add the pool implementation to the Rust code +2. Add FFI exports for pool functions +3. Update the Kotlin bindings to expose the pool functions +4. Create utility classes for Kotlin developers +5. Write example code and documentation + +This approach should provide significant performance improvements for high-throughput scenarios while maintaining backward compatibility with existing code. \ No newline at end of file diff --git a/bindings/kotlin/example/AgentService.kt b/bindings/kotlin/example/AgentService.kt new file mode 100644 index 000000000000..e4676c6c9a16 --- /dev/null +++ b/bindings/kotlin/example/AgentService.kt @@ -0,0 +1,239 @@ +import kotlinx.coroutines.* +import uniffi.goose_llm.* +import java.time.Instant +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger +import kotlin.time.Duration.Companion.seconds + +/** + * Production-ready agent service with connection pooling + * + * This service is designed for high-throughput parallel agent processing + * using the connection pool for optimal performance. + */ +class AgentService( + private val providerName: String, + private val providerConfig: String, + private val modelName: String, + private val maxConcurrentAgents: Int = 20, + private val useConnectionPooling: Boolean = true +) { + // Service state + private val sessionMap = ConcurrentHashMap() + private val activeRequests = AtomicInteger(0) + private val completionScope = CoroutineScope(Dispatchers.Default) + + // Performance metrics + private val totalRequests = AtomicInteger(0) + private val totalErrors = AtomicInteger(0) + private val totalResponseTimeMs = AtomicInteger(0) + + // Initialize with custom pool settings + init { + // Configure connection pool for optimal parallelism + ProviderPool.initialize() + ProviderPool.configure( + maxSize = maxConcurrentAgents, + maxIdleSec = 300, // 5 minutes idle timeout + maxLifetimeSec = 1800, // 30 minutes lifetime + maxUses = 100 // Max uses per connection + ) + } + + /** + * Process an agent message in a specific session + */ + suspend fun processMessage(sessionId: String, userMessage: String): AgentResponse = withContext(Dispatchers.IO) { + val startTime = System.currentTimeMillis() + activeRequests.incrementAndGet() + + try { + totalRequests.incrementAndGet() + + // Get or create session + val session = sessionMap.computeIfAbsent(sessionId) { + AgentSession(sessionId) + } + + // Process the message + val response = session.addUserMessage(userMessage) + .processWithPooling(useConnectionPooling) + + // Update metrics + val elapsed = System.currentTimeMillis() - startTime + totalResponseTimeMs.addAndGet(elapsed.toInt()) + + response + } catch (e: Exception) { + totalErrors.incrementAndGet() + AgentResponse( + text = "Error: ${e.message}", + processingTimeMs = System.currentTimeMillis() - startTime, + error = true + ) + } finally { + activeRequests.decrementAndGet() + } + } + + /** + * Process multiple agent messages in parallel + */ + suspend fun processMessagesInParallel(requests: List): List = coroutineScope { + requests.map { request -> + async(Dispatchers.IO.limitedParallelism(maxConcurrentAgents)) { + processMessage(request.sessionId, request.message) + } + }.awaitAll() + } + + /** + * Get current service metrics + */ + fun getMetrics(): ServiceMetrics { + val totalReq = totalRequests.get() + return ServiceMetrics( + totalRequests = totalReq, + totalErrors = totalErrors.get(), + activeRequests = activeRequests.get(), + avgResponseTimeMs = if (totalReq > 0) totalResponseTimeMs.get() / totalReq else 0, + poolStats = ProviderPool.stats() + ) + } + + /** + * Session handling for agent conversations + */ + inner class AgentSession(val sessionId: String) { + private val messages = mutableListOf() + private val createdAt = Instant.now() + + fun addUserMessage(text: String): AgentSession { + messages.add( + Message( + role = Role.USER, + created = System.currentTimeMillis() / 1000, + content = listOf(MessageContent.Text(TextContent(text))) + ) + ) + return this + } + + suspend fun processWithPooling(usePool: Boolean): AgentResponse { + val startTime = System.currentTimeMillis() + + // Create model config + val modelConfig = ModelConfig( + modelName = modelName, + contextLimit = 100000u, + temperature = 0.7f, + maxTokens = 1000 + ) + + // Create completion request (with or without pooling) + val request = createPooledCompletionRequest( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig, + systemPreamble = "You are a helpful assistant.", + messages = messages, + extensions = emptyList(), + usePool = usePool + ) + + // Process the completion + val result = completion(request) + + // Add assistant response to session history + messages.add(result.message) + + // Extract text content from the response + val responseText = result.message.content + .filterIsInstance() + .joinToString("\n") { it.v1.text } + + return AgentResponse( + text = responseText, + processingTimeMs = System.currentTimeMillis() - startTime, + error = false + ) + } + } +} + +/** + * Agent request for batch processing + */ +data class AgentRequest( + val sessionId: String, + val message: String +) + +/** + * Agent response with metrics + */ +data class AgentResponse( + val text: String, + val processingTimeMs: Long, + val error: Boolean +) + +/** + * Service metrics for monitoring + */ +data class ServiceMetrics( + val totalRequests: Int, + val totalErrors: Int, + val activeRequests: Int, + val avgResponseTimeMs: Int, + val poolStats: String +) + +/** + * Example usage of the AgentService + */ +suspend fun main() = coroutineScope { + // Setup provider config + val providerName = System.getenv("PROVIDER_NAME") ?: "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerConfig = """{"host": "$host", "token": "$token"}""" + val modelName = System.getenv("MODEL_NAME") ?: "goose-gpt-4-1" + + // Create agent service + val agentService = AgentService( + providerName = providerName, + providerConfig = providerConfig, + modelName = modelName, + maxConcurrentAgents = 5, + useConnectionPooling = true // Enable connection pooling + ) + + // Process a few requests in parallel + val responses = agentService.processMessagesInParallel(listOf( + AgentRequest("session1", "What is connection pooling?"), + AgentRequest("session2", "Explain the benefits of async programming"), + AgentRequest("session3", "How does Kotlin coroutines work?"), + AgentRequest("session1", "How does it compare to thread pools?"), + AgentRequest("session2", "What about error handling?") + )) + + // Print responses + responses.forEachIndexed { i, response -> + println("Response ${i+1}:") + println("Text: ${response.text.take(100)}...") + println("Processing time: ${response.processingTimeMs}ms") + println("Error: ${response.error}") + println("-----") + } + + // Print metrics + val metrics = agentService.getMetrics() + println("Service Metrics:") + println(" Total requests: ${metrics.totalRequests}") + println(" Total errors: ${metrics.totalErrors}") + println(" Active requests: ${metrics.activeRequests}") + println(" Average response time: ${metrics.avgResponseTimeMs}ms") + println("\nPool Stats:") + println(metrics.poolStats) +} \ No newline at end of file diff --git a/bindings/kotlin/example/ConnectionPoolingPlan.kt b/bindings/kotlin/example/ConnectionPoolingPlan.kt new file mode 100644 index 000000000000..5c41bd1c0d9d --- /dev/null +++ b/bindings/kotlin/example/ConnectionPoolingPlan.kt @@ -0,0 +1,138 @@ +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking +import uniffi.goose_llm.* + +/** + * This example shows a minimal proof-of-concept for implementing connection pooling + * with the existing code, without requiring Rust-side changes. + */ +fun main() = runBlocking { + println("Connection Pooling for goose-llm") + println("================================") + println("This is a simple proof of concept for connection pooling.") + println("Here's how we can implement connection pooling with minimal changes:") + + // Define our provider pool in Kotlin using a simple cache + val providerCache = ProviderCache() + + println("\n1. Making parallel requests WITHOUT connection pooling...") + val nonPooledResults = runParallelRequests(3, useCache = false) + + println("\n2. Making parallel requests WITH connection pooling...") + val pooledResults = runParallelRequests(3, useCache = true) + + // Calculate and print statistics + val avgNonPooled = nonPooledResults.average() + val avgPooled = pooledResults.average() + val improvement = (avgNonPooled - avgPooled) * 100 / avgNonPooled + + println("\nResults:") + println("- Average time without pooling: ${String.format("%.2f", avgNonPooled)}ms") + println("- Average time with pooling: ${String.format("%.2f", avgPooled)}ms") + println("- Performance improvement: ${String.format("%.2f", improvement)}%") + + println("\nProvider Cache Statistics:") + println("- Created: ${providerCache.created}") + println("- Retrieved: ${providerCache.retrieved}") + println("- Current Size: ${providerCache.size()}") +} + +/** + * Run multiple parallel requests and record the timings + */ +suspend fun runParallelRequests(count: Int, useCache: Boolean): List = coroutineScope { + // Create tasks + val tasks = (1..count).map { id -> + async { + val startTime = System.currentTimeMillis() + println(" Starting request $id...") + + // Make the request + try { + // In reality this would call the provider with completion + simulateProviderUsage(useCache) + + val duration = System.currentTimeMillis() - startTime + println(" Completed request $id in ${duration}ms") + duration + } catch (e: Exception) { + println(" Request $id failed: ${e.message}") + -1L + } + } + } + + // Wait for all tasks and filter out errors + tasks.awaitAll().filter { it > 0 } +} + +/** + * Simulate provider usage with or without a cache + * Note: In a real implementation, this would use the actual LLM providers + */ +fun simulateProviderUsage(useCache: Boolean) { + // Simulate provider creation and usage time + val createTime = 500L // 500ms to create a provider + val useTime = 1000L // 1000ms to use the provider + + if (useCache) { + // With caching - only creation time for first use + val provider = ProviderCache.get("sample-provider") + Thread.sleep(useTime) // Simulate provider usage time + } else { + // Without caching - pay creation cost every time + Thread.sleep(createTime) // Simulate provider creation time + Thread.sleep(useTime) // Simulate provider usage time + } +} + +/** + * A simple provider cache to simulate connection pooling + */ +class ProviderCache { + companion object { + private val cache = mutableMapOf() + var created = 0 + var retrieved = 0 + + fun get(key: String): Any { + return if (cache.containsKey(key)) { + retrieved++ + cache[key]!! + } else { + // Create new provider + created++ + Thread.sleep(500) // Simulate creation time + val provider = Any() // In reality this would be the provider + cache[key] = provider + provider + } + } + + fun size(): Int = cache.size + } +} + +/** + * Implementation Plan for Real Connection Pooling + * + * 1. Rust Side: + * - Create a thread-safe provider pool using Arc>>>> + * - Add pool configuration settings (max size, idle timeout, etc.) + * - Export get_provider/return_provider functions via FFI + * - Add pool statistics + * + * 2. Kotlin Side: + * - Create wrapper classes to manage the provider pool + * - Add utility functions for pool management + * - Integrate with completion requests + * + * 3. Benefits: + * - Reduced latency for repeated requests + * - Better resource utilization + * - Improved performance for parallel requests + * - More stable connection handling + * - Configurable pool behavior + */ \ No newline at end of file diff --git a/bindings/kotlin/example/PoolingDemo.kt b/bindings/kotlin/example/PoolingDemo.kt new file mode 100644 index 000000000000..ae828ea0d094 --- /dev/null +++ b/bindings/kotlin/example/PoolingDemo.kt @@ -0,0 +1,143 @@ +import kotlinx.coroutines.runBlocking +import uniffi.goose_llm.* +import kotlin.random.Random + +fun main() = runBlocking { + // Setup provider config + val providerName = System.getenv("PROVIDER_NAME") ?: "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerConfig = """{"host": "$host", "token": "$token"}""" + + val modelName = System.getenv("MODEL_NAME") ?: "goose-gpt-4-1" + val modelConfig = ModelConfig( + modelName, + 100000u, // context limit + 0.1f, // temperature + 200 // max tokens + ) + + // Configure the provider pool + println("Initializing and configuring provider pool...") + ProviderPool.initialize() + ProviderPool.configure( + maxSize = 5, + maxIdleSec = 300, + maxLifetimeSec = 3600, + maxUses = 50 + ) + + println("Running provider pool performance test...") + val tester = PoolPerformanceTester( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig + ) + + // Run a benchmark with simple questions + val result = tester.runBenchmark( + iterations = 10, // Run 10 iterations + parallelism = 3, // 3 parallel requests + messageSupplier = { + // Generate random test messages + val question = getRandomQuestion() + listOf( + Message( + role = Role.USER, + created = System.currentTimeMillis() / 1000, + content = listOf( + MessageContent.Text(TextContent(question)) + ) + ) + ) + } + ) + + // Print benchmark results + println("\nPerformance Benchmark Results") + println("============================") + println(result) + + // Live agent example (interactive loop) + println("\nWould you like to try an interactive demo? (y/n)") + val input = readlnOrNull() + if (input?.lowercase()?.startsWith("y") == true) { + runInteractiveDemo(providerName, providerConfig, modelConfig) + } +} + +fun runInteractiveDemo(providerName: String, providerConfig: String, modelConfig: ModelConfig) = runBlocking { + println("\nStarting interactive demo with connection pooling...\n") + + val service = ParallelCompletionService( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig, + usePool = true + ) + + val messages = mutableListOf() + + while (true) { + print("\nYou (or 'exit' to quit): ") + val userInput = readlnOrNull() ?: continue + + if (userInput.lowercase() == "exit") break + + // Add user message + messages.add( + Message( + role = Role.USER, + created = System.currentTimeMillis() / 1000, + content = listOf(MessageContent.Text(TextContent(userInput))) + ) + ) + + // Process with connection pooling + try { + print("Assistant: ") + val response = service.process(messages) + + // Add assistant response to conversation history + messages.add(response.message) + + // Extract text from message content + val responseText = response.message.content + .filterIsInstance() + .joinToString("\n") { it.v1.text } + + println(responseText) + + // Print some performance metrics + println("\n[Processing time: ${response.runtimeMetrics.totalTimeSec}s]") + } catch (e: Exception) { + println("Error: ${e.message}") + } + } + + // Print final pool stats + println("\nFinal provider pool statistics:") + println(ProviderPool.stats()) +} + +// List of random questions for the benchmark +fun getRandomQuestion(): String { + val questions = listOf( + "What is the capital of France?", + "How does photosynthesis work?", + "Explain the basics of quantum computing", + "What are the main causes of climate change?", + "Who wrote the novel Pride and Prejudice?", + "What is the difference between machine learning and AI?", + "How do vaccines work?", + "What are black holes?", + "Explain the theory of relativity", + "What is the Pythagorean theorem?", + "How does the internet work?", + "What is blockchain technology?", + "What's the difference between HTTP and HTTPS?", + "How does GPS navigation work?", + "What is the history of the Olympic Games?" + ) + return questions[Random.nextInt(questions.size)] +} \ No newline at end of file diff --git a/bindings/kotlin/example/ProviderPool.kt b/bindings/kotlin/example/ProviderPool.kt new file mode 100644 index 000000000000..3191ea6c3571 --- /dev/null +++ b/bindings/kotlin/example/ProviderPool.kt @@ -0,0 +1,235 @@ +import kotlinx.coroutines.* +import uniffi.goose_llm.* +import kotlin.system.measureTimeMillis + +/** + * Provider pool for managing and optimizing LLM provider connections + * + * This class wraps the underlying Rust provider pool system to make it easy + * to use from Kotlin. It provides methods to configure and monitor the pool. + * + * Note: The connection pool settings are passed to the Rust layer as u32 types, + * so very large values (> 2^31-1) will be clamped. + */ +class ProviderPool { + companion object { + /** + * Initialize the provider pool with default settings + */ + fun initialize() { + initProviderPool() + } + + /** + * Configure the provider pool with custom settings + * + * @param maxSize The maximum number of connections in the pool + * @param maxIdleSec Maximum time a connection can be idle before cleanup + * @param maxLifetimeSec Maximum lifetime of a connection + * @param maxUses Maximum number of uses for a connection + */ + fun configure(maxSize: Int = 10, maxIdleSec: Long = 300, + maxLifetimeSec: Long = 3600, maxUses: Int = 100) { + configureProviderPool(maxSize, maxIdleSec, maxLifetimeSec, maxUses) + } + + /** + * Get statistics about the current connection pool + * + * @return A string containing the pool statistics + */ + fun stats(): String { + return getPoolStats() + } + } +} + +/** + * Extension function to create a completion request with pool options + * + * @param usePool Whether to use the connection pool + */ +fun createPooledCompletionRequest( + providerName: String, + providerConfig: String, + modelConfig: ModelConfig, + systemPreamble: String, + messages: List, + extensions: List, + usePool: Boolean = true +): CompletionRequest { + return createCompletionRequest( + providerName, + providerConfig, + modelConfig, + systemPreamble, + messages, + extensions, + usePool + ) +} + +/** + * Parallel completion service for handling multiple requests efficiently + * + * This class helps manage multiple parallel completion requests with + * connection pooling for optimal performance. + */ +class ParallelCompletionService( + private val providerName: String, + private val providerConfig: String, + private val modelConfig: ModelConfig, + private val systemPreamble: String = "You are a helpful assistant.", + private val extensions: List = emptyList(), + private val maxConcurrency: Int = 5, + private val usePool: Boolean = true +) { + private val dispatcher = Dispatchers.IO.limitedParallelism(maxConcurrency) + + init { + // Initialize the provider pool + ProviderPool.initialize() + } + + /** + * Process multiple messages in parallel + * + * @param messages List of message lists to process + * @return List of completion responses + */ + suspend fun processInParallel(messages: List>): List = + coroutineScope { + messages.map { messageList -> + async(dispatcher) { + val request = createPooledCompletionRequest( + providerName, + providerConfig, + modelConfig, + systemPreamble, + messageList, + extensions, + usePool + ) + completion(request) + } + }.awaitAll() + } + + /** + * Process a single message list using the service configuration + * + * @param messages Message list to process + * @return Completion response + */ + suspend fun process(messages: List): CompletionResponse { + val request = createPooledCompletionRequest( + providerName, + providerConfig, + modelConfig, + systemPreamble, + messages, + extensions, + usePool + ) + return completion(request) + } +} + +/** + * Performance testing functions for comparing pooled vs non-pooled performance + */ +class PoolPerformanceTester( + private val providerName: String, + private val providerConfig: String, + private val modelConfig: ModelConfig, + private val systemPreamble: String = "You are a helpful assistant." +) { + // Configure the pool with default settings + init { + ProviderPool.initialize() + ProviderPool.configure() + } + + /** + * Run a performance benchmark comparing pooled vs non-pooled completions + * + * @param iterations Number of iterations to perform + * @param parallelism Maximum parallel requests + * @param messageSupplier Function to generate test messages + * @return Benchmark results + */ + suspend fun runBenchmark( + iterations: Int = 10, + parallelism: Int = 3, + messageSupplier: () -> List + ): BenchmarkResult = coroutineScope { + // Run test with pooling + val poolTimeMs = measureTimeMillis { + processRequests(iterations, parallelism, true, messageSupplier) + } + + delay(1000) // Give pool time to stabilize between tests + + // Run test without pooling + val nonPoolTimeMs = measureTimeMillis { + processRequests(iterations, parallelism, false, messageSupplier) + } + + // Calculate improvement percentage + val improvementPercent = if (nonPoolTimeMs > 0) { + ((nonPoolTimeMs - poolTimeMs) * 100.0) / nonPoolTimeMs + } else 0.0 + + BenchmarkResult( + iterations = iterations, + parallelism = parallelism, + pooledTimeMs = poolTimeMs, + nonPooledTimeMs = nonPoolTimeMs, + improvementPercent = improvementPercent, + poolStats = ProviderPool.stats() + ) + } + + private suspend fun processRequests( + iterations: Int, + parallelism: Int, + usePool: Boolean, + messageSupplier: () -> List + ) = coroutineScope { + val service = ParallelCompletionService( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig, + systemPreamble = systemPreamble, + maxConcurrency = parallelism, + usePool = usePool + ) + + val messages = (1..iterations).map { messageSupplier() } + service.processInParallel(messages) + } +} + +/** + * Results from a pooling performance benchmark + */ +data class BenchmarkResult( + val iterations: Int, + val parallelism: Int, + val pooledTimeMs: Long, + val nonPooledTimeMs: Long, + val improvementPercent: Double, + val poolStats: String +) { + override fun toString(): String = buildString { + append("Benchmark Results:\n") + append("------------------\n") + append("Iterations: $iterations\n") + append("Parallelism: $parallelism\n") + append("Pooled time: ${pooledTimeMs}ms (${pooledTimeMs / iterations}ms per request)\n") + append("Non-pooled time: ${nonPooledTimeMs}ms (${nonPooledTimeMs / iterations}ms per request)\n") + append("Improvement: ${String.format("%.2f", improvementPercent)}%\n") + append("\nPool Statistics:\n") + append(poolStats) + } +} \ No newline at end of file diff --git a/bindings/kotlin/example/SimplePoolingExample.kt b/bindings/kotlin/example/SimplePoolingExample.kt new file mode 100644 index 000000000000..5f320d4ec184 --- /dev/null +++ b/bindings/kotlin/example/SimplePoolingExample.kt @@ -0,0 +1,116 @@ +import kotlinx.coroutines.* +import uniffi.goose_llm.* +import kotlin.system.measureNanoTime + +/** + * Simple example showing connection pooling in action. + * This example avoids the experimental/advanced features and just shows basic pooling. + */ +fun main() = runBlocking { + // Get provider details from environment variables + val providerName = System.getenv("PROVIDER_NAME") ?: "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerConfig = """{"host": "$host", "token": "$token"}""" + + val modelName = System.getenv("MODEL_NAME") ?: "goose-gpt-4-1" + val modelConfig = ModelConfig( + modelName = modelName, + contextLimit = 100000u, + temperature = 0.7f, + maxTokens = 1000 + ) + + println("Initializing connection pool...") + // Initialize the provider pool with default settings + try { + initProviderPool() + } catch (e: Exception) { + println("Warning: Pool initialization failed: ${e.message}") + println("Proceeding without connection pooling") + } + + // Set up a simple question + val question = "What is connection pooling and why is it useful?" + val messages = listOf( + Message( + role = Role.USER, + created = System.currentTimeMillis() / 1000, + content = listOf(MessageContent.Text(TextContent(question))) + ) + ) + + // First call without pooling for comparison + println("\nMaking request WITHOUT connection pooling...") + var duration1 = 0L + val request1 = createCompletionRequest( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig, + systemPreamble = "You are a helpful assistant.", + messages = messages, + extensions = emptyList(), + usePool = false // Disable pooling for this request + ) + + // Process the completion + try { + var response1: CompletionResponse + duration1 = measureNanoTime { + response1 = completion(request1) + } + + // Extract and print response + val text1 = response1.message.content + .filterIsInstance() + .joinToString { it.v1.text } + + println("Response: ${text1.take(150)}...") + println("Time taken: ${duration1 / 1_000_000}ms") + } catch (e: Exception) { + println("Error with non-pooled request: ${e.message}") + } + + // Second call with pooling + println("\nMaking request WITH connection pooling...") + var duration2 = 0L + val request2 = createCompletionRequest( + providerName = providerName, + providerConfig = providerConfig, + modelConfig = modelConfig, + systemPreamble = "You are a helpful assistant.", + messages = messages, + extensions = emptyList(), + usePool = true // Enable pooling for this request + ) + + // Process the completion + try { + var response2: CompletionResponse + duration2 = measureNanoTime { + response2 = completion(request2) + } + + // Extract and print response + val text2 = response2.message.content + .filterIsInstance() + .joinToString { it.v1.text } + + println("Response: ${text2.take(150)}...") + println("Time taken: ${duration2 / 1_000_000}ms") + + // Calculate speedup + if (duration1 > 0) { + val speedup = (duration1.toDouble() / duration2.toDouble()) - 1.0 + println("Speedup with connection pooling: ${String.format("%.2f", speedup * 100)}%") + } + + // Show pool stats + println("\nProvider Pool Statistics:") + println(getPoolStats()) + } catch (e: Exception) { + println("Error with pooled request: ${e.message}") + } + + println("\nDone!") +} \ No newline at end of file diff --git a/bindings/kotlin/example/Usage.kt b/bindings/kotlin/example/Usage.kt index cdb06c8211db..5f2a2aaaf5d1 100644 --- a/bindings/kotlin/example/Usage.kt +++ b/bindings/kotlin/example/Usage.kt @@ -2,6 +2,16 @@ import kotlinx.coroutines.runBlocking import uniffi.goose_llm.* fun main() = runBlocking { + // Initialize the provider pool (optional, but good practice) + initProviderPool() + + // Configure the provider pool with custom settings + configureProviderPool( + maxSize = 10, // Max 10 connections in the pool + maxIdleSec = 300, // 5 minutes max idle time + maxLifetimeSec = 3600, // 1 hour max lifetime + maxUses = 100 // Max 100 uses per connection + ) val now = System.currentTimeMillis() / 1000 val msgs = listOf( // 1) User sends a plain-text prompt @@ -186,6 +196,7 @@ fun main() = runBlocking { ) )), extensions = extensions + usePool = true // Enable connection pooling ) val respToolErr = completion(reqToolErr) @@ -237,44 +248,107 @@ fun main() = runBlocking { suspend fun runUiExtraction(providerName: String, providerConfig: String) { val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI." + val messageText = """ + [ + { + "year": 2015, + "unique_artists": 71 + }, + { + "year": 2016, + "unique_artists": 51 + }, + { + "year": 2017, + "unique_artists": 121 + }, + { + "year": 2018, + "unique_artists": 92 + }, + { + "year": 2019, + "unique_artists": 377 + }, + { + "year": 2020, + "unique_artists": 335 + }, + { + "year": 2021, + "unique_artists": 383 + }, + { + "year": 2022, + "unique_artists": 444 + }, + { + "year": 2023, + "unique_artists": 510 + }, + { + "year": 2024, + "unique_artists": 627 + }, + { + "year": 2025, + "unique_artists": 243 + } +] +""".trimIndent() + val messages = listOf( Message( role = Role.USER, created = System.currentTimeMillis() / 1000, content = listOf( MessageContent.Text( - TextContent("Make a User Profile Form") + TextContent(messageText) ) ) ) ) + + val schema2 = """ + { + "type": "object", + "properties": { + "chartType": { + "const": "line", + "type": "string" + }, + "xAxis": { + "type": "array", + "items": { "type": "number" } + }, + "yAxis": { + "type": "array", + "items": { "type": "number" } + } + }, + "required": ["chartType", "xAxis", "yAxis"], + "additionalProperties": false +} +""".trimIndent(); + val schema = """{ - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["div","button","header","section","field","form"] - }, - "label": { "type": "string" }, - "children": { - "type": "array", - "items": { "${'$'}ref": "#" } - }, - "attributes": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { "type": "string" }, - "value": { "type": "string" } - }, - "required": ["name","value"], - "additionalProperties": false - } - } - }, - "required": ["type","label","children","attributes"], - "additionalProperties": false + "properties": { + "chartType": { + "const": "line", + "type": "string" + }, + "xAxis": { + "title": "Year", + "type": "string" + }, + "yAxis": { + "title": "Number of Unique Artists", + "type": "number" + } + }, + "type": "object", + "additionalProperties": false, + "required": ["chartType", "xAxis", "yAxis"] }""".trimIndent(); try { @@ -283,7 +357,7 @@ suspend fun runUiExtraction(providerName: String, providerConfig: String) { providerConfig = providerConfig, systemPrompt = systemPrompt, messages = messages, - schema = schema + schema = schema2 ) println("\nUI Extraction Output:\n${response}") } catch (e: ProviderException) { diff --git a/bindings/kotlin/uniffi/goose_llm/ProviderPool.kt b/bindings/kotlin/uniffi/goose_llm/ProviderPool.kt new file mode 100644 index 000000000000..9e42ab5c1779 --- /dev/null +++ b/bindings/kotlin/uniffi/goose_llm/ProviderPool.kt @@ -0,0 +1,81 @@ +// This file provides direct access to provider pooling functionality + +package uniffi.goose_llm + +/** + * Initialize the provider pool with default configuration + */ +fun initProviderPool() { + uniffiRustCall { _status -> + UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_init_provider_pool(_status) + } +} + +/** + * Configure the provider pool with custom settings + * + * @param maxSize Maximum number of providers in the pool + * @param maxIdleSec Maximum idle time in seconds before a provider is removed + * @param maxLifetimeSec Maximum lifetime in seconds for a provider + * @param maxUses Maximum number of uses for a provider + */ +fun configureProviderPool( + maxSize: Int, + maxIdleSec: Long, + maxLifetimeSec: Long, + maxUses: Int +) { + uniffiRustCall { _status -> + UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_configure_provider_pool( + maxSize.toUInt().toInt(), // Convert to match Rust's u32 + maxIdleSec, + maxLifetimeSec, + maxUses.toUInt().toInt(), // Convert to match Rust's u32 + _status + ) + } +} + +/** + * Get statistics about the provider pool + * + * @return A string representation of the pool statistics + */ +fun getPoolStats(): String { + return FfiConverterString.lift( + uniffiRustCall { _status -> + UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_get_pool_stats(_status) + } + ) +} + +/** + * Extension function to create a completion request with pool options + * + * @param usePool Whether to use the connection pool + */ +fun createCompletionRequest( + providerName: String, + providerConfig: Value, + modelConfig: ModelConfig, + systemPreamble: String, + messages: List, + extensions: List, + usePool: Boolean? = null +): CompletionRequest { + return FfiConverterTypeCompletionRequest.lift( + uniffiRustCall { _status -> + UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_create_completion_request_with_pool( + FfiConverterString.lower(providerName), + FfiConverterTypeValue.lower(providerConfig), + FfiConverterTypeModelConfig.lower(modelConfig), + FfiConverterString.lower(systemPreamble), + FfiConverterSequenceTypeMessage.lower(messages), + FfiConverterSequenceTypeExtensionConfig.lower(extensions), + // Convert Boolean? to Byte? for FFI + if (usePool == null) null else if (usePool) 1.toByte() else 0.toByte(), + _status + ) + } + ) +} \ No newline at end of file diff --git a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt index 76e60aaf7441..238fd21eb93e 100644 --- a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt +++ b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt @@ -835,6 +835,33 @@ internal interface UniffiLib : Library { `extensions`: RustBuffer.ByValue, uniffi_out_err: UniffiRustCallStatus, ): RustBuffer.ByValue + + fun uniffi_goose_llm_fn_func_create_completion_request_with_pool( + `providerName`: RustBuffer.ByValue, + `providerConfig`: RustBuffer.ByValue, + `modelConfig`: RustBuffer.ByValue, + `systemPreamble`: RustBuffer.ByValue, + `messages`: RustBuffer.ByValue, + `extensions`: RustBuffer.ByValue, + `usePool`: Byte?, + uniffi_out_err: UniffiRustCallStatus, + ): RustBuffer.ByValue + + fun uniffi_goose_llm_fn_func_init_provider_pool( + uniffi_out_err: UniffiRustCallStatus, + ): Unit + + fun uniffi_goose_llm_fn_func_configure_provider_pool( + maxSize: Int, // This is a u32 in Rust + maxIdleSec: Long, // u64 in Rust + maxLifetimeSec: Long, // u64 in Rust + maxUses: Int, // This is a u32 in Rust + uniffi_out_err: UniffiRustCallStatus, + ): Unit + + fun uniffi_goose_llm_fn_func_get_pool_stats( + uniffi_out_err: UniffiRustCallStatus, + ): RustBuffer.ByValue fun uniffi_goose_llm_fn_func_create_tool_config( `name`: RustBuffer.ByValue, @@ -3077,8 +3104,5 @@ suspend fun `generateTooltip`( fun `printMessages`(`messages`: List) = uniffiRustCall { _status -> - UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_print_messages( - FfiConverterSequenceTypeMessage.lower(`messages`), - _status, - ) + UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_print_messages(FfiConverterSequenceTypeMessage.lower(`messages`), _status) } diff --git a/crates/goose-llm/Cargo.toml b/crates/goose-llm/Cargo.toml index 17723e31aac4..ee7111b1209c 100644 --- a/crates/goose-llm/Cargo.toml +++ b/crates/goose-llm/Cargo.toml @@ -45,6 +45,10 @@ indoc = "1.0" # https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/fixtures/futures/Cargo.toml#L22 uniffi = { version = "0.29", features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] } tokio = { version = "1.43", features = ["time", "sync"] } +lazy_static = "1.5" +parking_lot = "0.12" +futures = "0.3" +md5 = "0.7" [dev-dependencies] criterion = "0.5" diff --git a/crates/goose-llm/examples/provider_pool.rs b/crates/goose-llm/examples/provider_pool.rs new file mode 100644 index 000000000000..c426d7b779c3 --- /dev/null +++ b/crates/goose-llm/examples/provider_pool.rs @@ -0,0 +1,81 @@ +use anyhow::Result; +use futures::future::join_all; +use goose_llm::{ + completion, configure_provider_pool, get_pool_stats, init_provider_pool, Message, ModelConfig, +}; +use serde_json::json; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize the provider pool with default configuration + init_provider_pool(); + + // Configure the pool with custom settings if desired + // Maximum pool size: 5 + // Maximum idle time: 60 seconds + // Maximum lifetime: 300 seconds (5 minutes) + // Maximum uses: 50 + configure_provider_pool(5, 60, 300, 50); + + // Create a request template + let model = ModelConfig::new("gpt-4o".to_string()); + + // Get the OpenAI API key from environment variable + let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set"); + + // Create the provider config + let provider_config = json!({ + "api_key": api_key, + "timeout": 60 + }); + + // Create multiple concurrent requests + let mut handles = Vec::new(); + for i in 0..3 { + let model_clone = model.clone(); + let provider_config_clone = provider_config.clone(); + + let handle = tokio::spawn(async move { + // Create a simple message + let messages = vec![Message::user().with_text(&format!( + "Count from {} to {}", + i * 5 + 1, + i * 5 + 5 + ))]; + + // Create a request + let req = goose_llm::types::completion::create_completion_request( + "openai", + provider_config_clone, + model_clone, + Some("You are a helpful assistant.".to_string()), + None, + messages, + vec![], + Some(true), // use the provider pool + ); + + // Execute the request + let response = completion(req).await?; + + // Print the response + println!( + "Request {}: {:?}\n", + i, + response.message.content[0] + ); + + Ok::<_, anyhow::Error>(()) + }); + + handles.push(handle); + } + + // Wait for all requests to complete + join_all(handles).await; + + // Print pool statistics + println!("\nProvider Pool Statistics:\n{}", get_pool_stats()); + + Ok(()) +} \ No newline at end of file diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs index d39b1b8db830..913527364947 100644 --- a/crates/goose-llm/src/completion.rs +++ b/crates/goose-llm/src/completion.rs @@ -3,11 +3,12 @@ use std::{collections::HashMap, time::Instant}; use anyhow::Result; use chrono::Utc; use serde_json::Value; +use tracing::{debug, error, info}; use crate::{ message::{Message, MessageContent}, prompt_template, - providers::create, + providers::{create, global_pool_manager, base::Provider}, types::{ completion::{ CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig, @@ -24,11 +25,115 @@ pub fn print_messages(messages: Vec) { } } +/// Initialize the provider pool with default configuration +#[uniffi::export] +pub fn init_provider_pool() { + // Initialize the provider pool with default configuration + global_pool_manager(); + debug!("Provider pool initialized with default configuration"); +} + +/// Configure the provider pool with custom settings +#[uniffi::export] +pub fn configure_provider_pool(max_size: u32, max_idle_sec: u64, max_lifetime_sec: u64, max_uses: u32) { + use std::time::Duration; + use crate::providers::PoolConfig; + + let _config = PoolConfig { + max_size: max_size as usize, + max_idle_time: Duration::from_secs(max_idle_sec), + max_lifetime: Duration::from_secs(max_lifetime_sec), + max_uses: max_uses as usize, + }; + + global_pool_manager(); // Initialize if not already + + // We can't directly configure the global pool manager after it's initialized, + // but this is useful for per-provider configuration + info!("Provider pool configured with max_size={}, max_idle_time={}s, max_lifetime={}s, max_uses={}", + max_size, max_idle_sec, max_lifetime_sec, max_uses); +} + +/// Get statistics about the provider pool +#[uniffi::export] +pub fn get_pool_stats() -> String { + let stats = global_pool_manager().get_all_stats(); + + if stats.is_empty() { + return "No active provider pools".to_string(); + } + + let mut result = String::new(); + for (name, stat) in stats { + result.push_str(&format!("Pool: {}\n", name)); + result.push_str(&format!(" Created: {}\n", stat.created)); + result.push_str(&format!(" Borrowed: {}\n", stat.borrowed)); + result.push_str(&format!(" Returned: {}\n", stat.returned)); + result.push_str(&format!(" Errors: {}\n", stat.errors)); + result.push_str(&format!(" Max Pool Size: {}\n", stat.max_pool_size)); + result.push_str(&format!(" Current Pool Size: {}\n", stat.current_pool_size)); + result.push_str(&format!(" Waiting: {}\n", stat.waiting)); + result.push_str("\n"); + } + + result +} + /// Public API for the Goose LLM completion function #[uniffi::export(async_runtime = "tokio")] pub async fn completion(req: CompletionRequest) -> Result { let start_total = Instant::now(); + let system_prompt = construct_system_prompt( + &req.system_preamble, + &req.system_prompt_override, + &req.extensions, + )?; + let tools = collect_prefixed_tools(&req.extensions); + + // Create a pooled provider or a direct provider based on the request + if req.use_pool { + // Try to get a provider from the pool by calling directly on the pool manager + let pool_manager = global_pool_manager(); + let pool = pool_manager.get_or_create_pool( + &req.provider_name, + req.provider_config.clone(), + req.model_config.clone(), + ); + + // Clone the Arc so we can move it into the match + let pool_clone = pool.clone(); + match pool_clone.get().await { + Ok(pooled_provider) => { + // Call the pooled provider + let start_provider_time = Instant::now(); + let mut response = pooled_provider + .complete(&system_prompt, &req.messages, &tools) + .await?; + + let provider_elapsed_sec = start_provider_time.elapsed().as_secs_f32(); + let usage_tokens = response.usage.total_tokens; + + let tool_configs = collect_prefixed_tool_configs(&req.extensions); + update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?; + + return Ok(CompletionResponse::new( + response.message, + response.model, + response.usage, + calculate_runtime_metrics(start_total, provider_elapsed_sec, usage_tokens), + )); + }, + Err(e) => { + error!("Failed to get provider from pool: {}", e); + // Fall back to creating a provider directly + debug!("Falling back to direct provider creation"); + } + } + } + + // Create a provider directly (either by choice or as fallback) + debug!("Using direct provider creation"); let provider = create( &req.provider_name, req.provider_config.clone(), diff --git a/crates/goose-llm/src/lib.rs b/crates/goose-llm/src/lib.rs index cd698356bcef..50b74da475fa 100644 --- a/crates/goose-llm/src/lib.rs +++ b/crates/goose-llm/src/lib.rs @@ -9,7 +9,7 @@ pub mod providers; mod structured_outputs; pub mod types; -pub use completion::completion; +pub use completion::{completion, configure_provider_pool, get_pool_stats, init_provider_pool}; pub use message::Message; pub use model::ModelConfig; pub use structured_outputs::generate_structured_outputs; diff --git a/crates/goose-llm/src/providers/base.rs b/crates/goose-llm/src/providers/base.rs index dcfecbd1e7f3..280e23714a8f 100644 --- a/crates/goose-llm/src/providers/base.rs +++ b/crates/goose-llm/src/providers/base.rs @@ -62,7 +62,7 @@ impl ProviderExtractResponse { /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] -pub trait Provider: Send + Sync { +pub trait Provider: Send + Sync + std::fmt::Debug { /// Generate the next message using the configured model and other parameters /// /// # Arguments diff --git a/crates/goose-llm/src/providers/factory.rs b/crates/goose-llm/src/providers/factory.rs index a70be3d44ef8..a532213e5907 100644 --- a/crates/goose-llm/src/providers/factory.rs +++ b/crates/goose-llm/src/providers/factory.rs @@ -9,6 +9,7 @@ use super::{ }; use crate::model::ModelConfig; +/// Create a new provider instance directly (without pooling) pub fn create( name: &str, provider_config: serde_json::Value, @@ -27,3 +28,4 @@ pub fn create( _ => Err(anyhow::anyhow!("Unknown provider: {}", name)), } } + diff --git a/crates/goose-llm/src/providers/mod.rs b/crates/goose-llm/src/providers/mod.rs index c808938048f9..d881f450386b 100644 --- a/crates/goose-llm/src/providers/mod.rs +++ b/crates/goose-llm/src/providers/mod.rs @@ -4,7 +4,9 @@ pub mod errors; mod factory; pub mod formats; pub mod openai; +pub mod pool; pub mod utils; pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage}; pub use factory::create; +pub use pool::{global_pool_manager, PoolConfig}; diff --git a/crates/goose-llm/src/providers/pool.rs b/crates/goose-llm/src/providers/pool.rs new file mode 100644 index 000000000000..546317c01db8 --- /dev/null +++ b/crates/goose-llm/src/providers/pool.rs @@ -0,0 +1,512 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, + }, + time::{Duration, Instant}, +}; + +use anyhow::Result; +use async_trait::async_trait; +use parking_lot::Mutex; +use serde_json::Value; +use tokio::sync::Semaphore; +use tracing::debug; + +use super::{ + base::{Provider, ProviderCompleteResponse, ProviderExtractResponse}, + errors::ProviderError, + factory, +}; +use crate::{message::Message, model::ModelConfig, types::core::Tool}; + +/// Statistics for the provider pool +#[derive(Debug, Clone, Default)] +pub struct PoolStats { + pub created: usize, + pub borrowed: usize, + pub returned: usize, + pub errors: usize, + pub max_pool_size: usize, + pub current_pool_size: usize, + pub waiting: usize, +} + +/// A pool entry containing a provider and metadata +#[derive(Debug)] +struct PooledProvider { + provider: Arc, + created_at: Instant, + last_used: Instant, + use_count: usize, +} + +impl PooledProvider { + fn new(provider: Arc) -> Self { + let now = Instant::now(); + Self { + provider, + created_at: now, + last_used: now, + use_count: 0, + } + } + + fn used(&mut self) { + self.last_used = Instant::now(); + self.use_count += 1; + } +} + +/// Configuration for the provider pool +#[derive(Debug, Clone)] +pub struct PoolConfig { + pub max_size: usize, + pub max_idle_time: Duration, + pub max_lifetime: Duration, + pub max_uses: usize, +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + max_size: 10, + max_idle_time: Duration::from_secs(300), // 5 minutes + max_lifetime: Duration::from_secs(3600), // 1 hour + max_uses: 100, + } + } +} + +/// ProviderPool manages a pool of provider instances +#[derive(Debug)] +pub struct ProviderPool { + // Pool configuration + config: PoolConfig, + + // Provider creation parameters + provider_name: String, + provider_config: Value, + model_config: ModelConfig, + + // Pool state + available: RwLock>, + in_use: AtomicUsize, + stats: Arc>, + + // Concurrency control + semaphore: Arc, +} + +impl ProviderPool { + /// Create a new provider pool + pub fn new( + provider_name: String, + provider_config: Value, + model_config: ModelConfig, + config: PoolConfig, + ) -> Self { + let stats = Arc::new(RwLock::new(PoolStats { + max_pool_size: config.max_size, + ..PoolStats::default() + })); + + Self { + config: config.clone(), + provider_name, + provider_config, + model_config, + available: RwLock::new(Vec::with_capacity(config.max_size)), + in_use: AtomicUsize::new(0), + stats, + semaphore: Arc::new(Semaphore::new(config.max_size)), + } + } + + /// Get a provider from the pool + pub async fn get(self: Arc) -> Result { + // Acquire a permit from the semaphore to limit concurrent requests + let _permit = self + .semaphore + .clone() + .acquire_owned() + .await + .map_err(|e| ProviderError::ExecutionError(format!("Failed to acquire semaphore: {}", e)))?; + + // Update waiting count + { + let mut stats = self.stats.write().unwrap(); + stats.waiting += 1; + } + + // Try to get a provider from the pool + let provider = self.get_or_create_provider().await?; + + // Update stats + { + let mut stats = self.stats.write().unwrap(); + stats.waiting -= 1; + stats.borrowed += 1; + stats.current_pool_size = self.available.read().unwrap().len() + self.in_use.load(Ordering::SeqCst); + } + + // Return the provider wrapped in a guard + Ok(PooledProviderGuard { + pool: self, + provider: Some(provider), + returned: false, + }) + } + + /// Get pool statistics + pub fn stats(&self) -> PoolStats { + self.stats.read().unwrap().clone() + } + + /// Internal method to get a provider from the pool or create a new one + async fn get_or_create_provider(&self) -> Result, ProviderError> { + // Try to get a provider from the pool first + if let Some(mut pooled) = self.take_available() { + // Check if the provider is still valid + let now = Instant::now(); + let idle_time = now.duration_since(pooled.last_used); + let lifetime = now.duration_since(pooled.created_at); + + if idle_time > self.config.max_idle_time + || lifetime > self.config.max_lifetime + || pooled.use_count >= self.config.max_uses { + // Provider expired, create a new one + debug!( + "Provider expired: idle_time={:?}, lifetime={:?}, use_count={}", + idle_time, lifetime, pooled.use_count + ); + drop(pooled); // Explicitly drop the expired provider + self.create_new_provider().await + } else { + // Update usage stats + pooled.used(); + self.in_use.fetch_add(1, Ordering::SeqCst); + Ok(pooled.provider) + } + } else { + // No available provider, create a new one + self.create_new_provider().await + } + } + + /// Take an available provider from the pool + fn take_available(&self) -> Option { + let mut pool = self.available.write().unwrap(); + if !pool.is_empty() { + Some(pool.remove(0)) + } else { + None + } + } + + /// Create a new provider + async fn create_new_provider(&self) -> Result, ProviderError> { + let provider = factory::create( + &self.provider_name, + self.provider_config.clone(), + self.model_config.clone(), + ) + .map_err(|e| ProviderError::ExecutionError(format!("Failed to create provider: {}", e)))?; + + // Update stats + { + let mut stats = self.stats.write().unwrap(); + stats.created += 1; + } + + self.in_use.fetch_add(1, Ordering::SeqCst); + Ok(provider) + } + + /// Return a provider to the pool + fn return_provider(&self, provider: Arc) { + let mut pool = self.available.write().unwrap(); + if pool.len() < self.config.max_size { + pool.push(PooledProvider::new(provider)); + } + // else drop the provider (it goes over capacity) + + self.in_use.fetch_sub(1, Ordering::SeqCst); + + // Update stats + { + let mut stats = self.stats.write().unwrap(); + stats.returned += 1; + stats.current_pool_size = pool.len() + self.in_use.load(Ordering::SeqCst); + } + } + + /// Handle an error with a provider + fn handle_error(&self) { + self.in_use.fetch_sub(1, Ordering::SeqCst); + + // Update stats + { + let mut stats = self.stats.write().unwrap(); + stats.errors += 1; + stats.current_pool_size = self.available.read().unwrap().len() + self.in_use.load(Ordering::SeqCst); + } + } + + /// Clean up idle providers in the pool + pub fn cleanup_idle(&self) -> usize { + let mut pool = self.available.write().unwrap(); + let now = Instant::now(); + let initial_size = pool.len(); + + // Remove providers that have been idle too long + pool.retain(|p| { + let idle_time = now.duration_since(p.last_used); + let lifetime = now.duration_since(p.created_at); + + idle_time <= self.config.max_idle_time + && lifetime <= self.config.max_lifetime + && p.use_count < self.config.max_uses + }); + + let removed = initial_size - pool.len(); + + // Update stats + if removed > 0 { + let mut stats = self.stats.write().unwrap(); + stats.current_pool_size = pool.len() + self.in_use.load(Ordering::SeqCst); + debug!("Cleaned up {} idle providers", removed); + } + + removed + } +} + +/// A guard that returns a provider to the pool when dropped +#[derive(Debug)] +pub struct PooledProviderGuard { + pool: Arc, + provider: Option>, + returned: bool, +} + +impl Drop for PooledProviderGuard { + fn drop(&mut self) { + // If the guard hasn't already returned the provider, return it now + if !self.returned { + if let Some(provider) = self.provider.take() { + self.pool.return_provider(provider); + } + } + } +} + +impl PooledProviderGuard { + /// Return the provider to the pool early + pub fn return_to_pool(mut self) { + if let Some(provider) = self.provider.take() { + self.pool.return_provider(provider); + self.returned = true; + } + } + + /// Mark the provider as having an error + pub fn handle_error(mut self) { + self.provider.take(); + self.pool.handle_error(); + self.returned = true; + } +} + +#[async_trait] +impl Provider for PooledProviderGuard { + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + match &self.provider { + Some(provider) => provider.complete(system, messages, tools).await, + None => Err(ProviderError::ExecutionError("Provider is not available".into())), + } + } + + async fn extract( + &self, + system: &str, + messages: &[Message], + schema: &Value + ) -> Result { + match &self.provider { + Some(provider) => provider.extract(system, messages, schema).await, + None => Err(ProviderError::ExecutionError("Provider is not available".into())), + } + } +} + +/// A provider pool manager that manages multiple pools +#[derive(Default)] +pub struct ProviderPoolManager { + pools: Arc>>>, + pool_configs: Arc>>, + default_config: PoolConfig, +} + +impl ProviderPoolManager { + /// Create a new provider pool manager with default configuration + pub fn new() -> Self { + Self { + pools: Arc::new(RwLock::new(HashMap::new())), + pool_configs: Arc::new(RwLock::new(HashMap::new())), + default_config: PoolConfig::default(), + } + } + + /// Set the default pool configuration + pub fn with_default_config(mut self, config: PoolConfig) -> Self { + self.default_config = config; + self + } + + /// Set a pool configuration for a specific provider + pub fn set_pool_config(&self, provider_name: &str, config: PoolConfig) { + let mut configs = self.pool_configs.write().unwrap(); + configs.insert(provider_name.to_string(), config); + } + + /// Get the pool configuration for a provider + fn get_pool_config(&self, provider_name: &str) -> PoolConfig { + let configs = self.pool_configs.read().unwrap(); + configs + .get(provider_name) + .cloned() + .unwrap_or_else(|| self.default_config.clone()) + } + + /// Get or create a provider pool + pub fn get_or_create_pool( + &self, + provider_name: &str, + provider_config: Value, + model_config: ModelConfig, + ) -> Arc { + let mut pools = self.pools.write().unwrap(); + + let key = self.create_pool_key(provider_name, &provider_config, &model_config); + + pools.entry(key.clone()).or_insert_with(|| { + let config = self.get_pool_config(provider_name); + Arc::new(ProviderPool::new( + provider_name.to_string(), + provider_config, + model_config, + config, + )) + }).clone() + } + + /// Create a key for the pool based on provider name, config, and model + fn create_pool_key(&self, provider_name: &str, provider_config: &Value, model_config: &ModelConfig) -> String { + // Create a key that uniquely identifies this provider configuration + // We use the provider name, a hash of the provider config, and the model name + let config_hash = format!("{:x}", md5::compute(provider_config.to_string())); + format!("{}:{}:{}", provider_name, config_hash, model_config.model_name) + } + + /// Get statistics for all pools + pub fn get_all_stats(&self) -> HashMap { + let pools = self.pools.read().unwrap(); + let mut stats = HashMap::with_capacity(pools.len()); + + for (key, pool) in pools.iter() { + stats.insert(key.clone(), pool.stats()); + } + + stats + } + + /// Clean up idle providers in all pools + pub fn cleanup_all_idle(&self) -> usize { + let pools = self.pools.read().unwrap(); + let mut total_removed = 0; + + for pool in pools.values() { + total_removed += pool.cleanup_idle(); + } + + total_removed + } + + /// Start a background task that periodically cleans up idle providers + pub fn start_cleanup_task(&self, interval: Duration) -> tokio::task::JoinHandle<()> { + let pools = self.pools.clone(); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(interval); + + loop { + interval.tick().await; + + let pools_ref = pools.read().unwrap(); + for pool in pools_ref.values() { + let removed = pool.cleanup_idle(); + if removed > 0 { + debug!("Cleaned up {} idle providers", removed); + } + } + } + }) + } + +} + +// Create a global provider pool manager +lazy_static::lazy_static! { + static ref GLOBAL_POOL_MANAGER: Mutex> = Mutex::new(None); +} + +/// Initialize the global provider pool manager +pub fn init_global_pool_manager(config: Option) -> &'static ProviderPoolManager { + let mut global = GLOBAL_POOL_MANAGER.lock(); + + if global.is_none() { + let manager = match config { + Some(config) => ProviderPoolManager::new().with_default_config(config), + None => ProviderPoolManager::new(), + }; + + // Start the cleanup task + let cleanup_interval = Duration::from_secs(60); // 1 minute + manager.start_cleanup_task(cleanup_interval); + + *global = Some(manager); + } + + // SAFETY: This is safe because: + // 1. We never remove the pool manager once initialized (it lives for the program duration) + // 2. The mutex ensures thread-safe access to the manager + // 3. The static reference is only to the contained manager which has a static lifetime + let static_manager = unsafe { + let manager_ref = global.as_ref().unwrap(); + std::mem::transmute::<&ProviderPoolManager, &'static ProviderPoolManager>(manager_ref) + }; + + static_manager +} + +/// Get the global provider pool manager, initializing it if needed +pub fn global_pool_manager() -> &'static ProviderPoolManager { + let global = GLOBAL_POOL_MANAGER.lock(); + + if let Some(manager) = &*global { + // SAFETY: This is safe because the ProviderPoolManager is stored in a static Mutex + // and lives for the entire program duration + unsafe { std::mem::transmute::<&ProviderPoolManager, &'static ProviderPoolManager>(manager) } + } else { + drop(global); // Release lock before calling init + init_global_pool_manager(None) + } +} + diff --git a/crates/goose-llm/src/types/completion.rs b/crates/goose-llm/src/types/completion.rs index 21e0bcd9ddd3..e5ed68a62b82 100644 --- a/crates/goose-llm/src/types/completion.rs +++ b/crates/goose-llm/src/types/completion.rs @@ -20,6 +20,12 @@ pub struct CompletionRequest { pub system_prompt_override: Option, pub messages: Vec, pub extensions: Vec, + #[serde(default = "default_use_pool")] + pub use_pool: bool, +} + +fn default_use_pool() -> bool { + true } impl CompletionRequest { @@ -40,8 +46,14 @@ impl CompletionRequest { system_preamble, messages, extensions, + use_pool: true, } } + + pub fn with_pool_option(mut self, use_pool: bool) -> Self { + self.use_pool = use_pool; + self + } } #[uniffi::export(default(system_preamble = None, system_prompt_override = None))] @@ -53,8 +65,9 @@ pub fn create_completion_request( system_prompt_override: Option, messages: Vec, extensions: Vec, + use_pool: Option, ) -> CompletionRequest { - CompletionRequest::new( + let request = CompletionRequest::new( provider_name.to_string(), provider_config, model_config, @@ -62,7 +75,13 @@ pub fn create_completion_request( system_prompt_override, messages, extensions, - ) + ); + + if let Some(use_pool) = use_pool { + request.with_pool_option(use_pool) + } else { + request + } } uniffi::custom_type!(CompletionRequest, String, {