From 05699cc59ba639b3fd9081bb4e8ecad46bdbfcec Mon Sep 17 00:00:00 2001 From: John Landa Date: Mon, 1 Sep 2025 22:03:08 -0700 Subject: [PATCH 1/2] feat: add streaming support to Tetrate Agent Router Service provider Signed-off-by: John Landa --- crates/goose/src/providers/tetrate.rs | 57 ++++- crates/goose/tests/tetrate_streaming.rs | 284 ++++++++++++++++++++++++ 2 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 crates/goose/tests/tetrate_streaming.rs diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index 2aacfc9e0b9c..ac5582c6b048 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -1,14 +1,22 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; -use serde_json::Value; +use futures::TryStreamExt; +use serde_json::{json, Value}; +use std::io; +use tokio::pin; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tokio_util::io::StreamReader; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; +use super::formats::openai::response_to_streaming_message; use super::retry::ProviderRetry; use super::utils::{ emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat, - is_google_model, + handle_status_openai_compat, is_google_model, }; use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL; use crate::conversation::message::Message; @@ -178,6 +186,49 @@ impl Provider for TetrateProvider { Ok((message, ProviderUsage::new(model, usage))) } + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = create_request( + &self.model, + system, + messages, + tools, + &super::utils::ImageFormat::OpenAi, + )?; + + // Enable streaming + payload["stream"] = json!(true); + payload["stream_options"] = json!({ + "include_usage": true, + }); + + let response = self + .api_client + .response_post("v1/chat/completions", &payload) + .await?; + + let response = handle_status_openai_compat(response).await?; + let stream = response.bytes_stream().map_err(io::Error::other); + let model_config = self.model.clone(); + + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } + /// Fetch supported models from Tetrate Agent Router Service API (only models with tool support) async fn fetch_supported_models(&self) -> Result>, ProviderError> { // Use the existing api_client which already has authentication configured diff --git a/crates/goose/tests/tetrate_streaming.rs b/crates/goose/tests/tetrate_streaming.rs new file mode 100644 index 000000000000..7f94411bd8e6 --- /dev/null +++ b/crates/goose/tests/tetrate_streaming.rs @@ -0,0 +1,284 @@ +use anyhow::Result; +use futures::StreamExt; +use goose::conversation::message::{Message, MessageContent}; +use goose::providers::base::Provider; +use goose::providers::tetrate::TetrateProvider; +use goose::model::ModelConfig; +use rmcp::model::Tool; +use rmcp::object; +use serial_test::serial; + +/// Test module for Tetrate Agent Router Service streaming functionality +#[cfg(test)] +mod tetrate_streaming_tests { + use super::*; + + fn create_test_provider() -> Result { + // Create a test provider with the default model + let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; + TetrateProvider::from_env(model_config) + } + + #[tokio::test] + #[serial] + #[ignore] // Ignore by default, run with --ignored flag when API key is available + async fn test_tetrate_streaming_basic() -> Result<()> { + let provider = create_test_provider()?; + + let messages = vec![ + Message::user().with_text("Count from 1 to 5, one number at a time.") + ]; + + let mut stream = provider + .stream("You are a helpful assistant that counts numbers.", &messages, &[]) + .await?; + + let mut chunk_count = 0; + let mut content_chunks = Vec::new(); + + while let Some(result) = stream.next().await { + let (message, usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + let text = msg.as_concat_text(); + if !text.is_empty() { + content_chunks.push(text); + } + } + + // Check if we have usage information in the final chunk + if usage.is_some() { + println!("Received usage information in chunk {}", chunk_count); + } + } + + assert!(chunk_count > 0, "Should receive at least one chunk"); + assert!(!content_chunks.is_empty(), "Should receive some content"); + + let full_content = content_chunks.join(""); + println!("Full streamed content: {}", full_content); + + // Verify the response contains numbers + assert!(full_content.contains('1'), "Response should contain number 1"); + assert!(full_content.contains('5'), "Response should contain number 5"); + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_with_tools() -> Result<()> { + let provider = create_test_provider()?; + + // Define a simple tool + let weather_tool = Tool::new( + "get_weather", + "Get the current weather for a location", + object!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + }), + ); + + let messages = vec![ + Message::user().with_text("What's the weather in San Francisco?") + ]; + + let mut stream = provider + .stream("You are a helpful assistant with access to weather information.", &messages, &[weather_tool]) + .await?; + + let mut received_tool_call = false; + let mut chunk_count = 0; + + while let Some(result) = stream.next().await { + let (message, _usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + // Check if message contains tool requests + for content in &msg.content { + if matches!(content, MessageContent::ToolRequest(_)) { + received_tool_call = true; + println!("Received tool call in chunk {}", chunk_count); + } + } + } + } + + assert!(chunk_count > 0, "Should receive at least one chunk"); + // Note: Tool calls might not be supported in streaming for all models + // This is more of a capability test than a requirement + if received_tool_call { + println!("✓ Streaming with tools is supported"); + } else { + println!("⚠ Streaming with tools may not be fully supported"); + } + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_empty_response() -> Result<()> { + let provider = create_test_provider()?; + + // This might result in a very short or empty response + let messages = vec![ + Message::user().with_text("") + ]; + + let mut stream = provider + .stream("You are a helpful assistant.", &messages, &[]) + .await?; + + let mut chunk_count = 0; + + while let Some(result) = stream.next().await { + let (_message, _usage) = result?; + chunk_count += 1; + } + + // Even with empty input, we should get at least one chunk (possibly with finish_reason) + assert!(chunk_count > 0, "Should receive at least one chunk even with empty input"); + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_long_response() -> Result<()> { + let provider = create_test_provider()?; + + let messages = vec![ + Message::user().with_text("Write a detailed 3-paragraph essay about the importance of streaming in modern APIs.") + ]; + + let mut stream = provider + .stream("You are a helpful assistant that writes detailed essays.", &messages, &[]) + .await?; + + let mut chunk_count = 0; + let mut total_content_length = 0; + + while let Some(result) = stream.next().await { + let (message, usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + let text = msg.as_concat_text(); + total_content_length += text.len(); + } + + // Final chunk should have usage information + if let Some(usage_info) = usage { + println!("Final usage: {:?}", usage_info.usage); + assert!(usage_info.usage.output_tokens.unwrap_or(0) > 0, "Should have output tokens"); + } + } + + println!("Received {} chunks with total content length: {}", chunk_count, total_content_length); + + // For a detailed essay, we expect multiple chunks and substantial content + assert!(chunk_count > 5, "Long response should be streamed in multiple chunks"); + assert!(total_content_length > 100, "Essay should have substantial content"); + + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_tetrate_streaming_error_handling() -> Result<()> { + // Test with invalid API key to ensure error handling works + std::env::set_var("TETRATE_API_KEY", "invalid-key-for-testing"); + + let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; + let provider = TetrateProvider::from_env(model_config)?; + + let messages = vec![ + Message::user().with_text("Hello") + ]; + + let result = provider + .stream("You are a helpful assistant.", &messages, &[]) + .await; + + // We expect this to fail with an authentication error + assert!(result.is_err(), "Should fail with invalid API key"); + + // Clean up + std::env::remove_var("TETRATE_API_KEY"); + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_concurrent_streams() -> Result<()> { + let provider = create_test_provider()?; + + // Create multiple concurrent streams + let messages1 = vec![Message::user().with_text("Say 'Stream 1'")]; + let messages2 = vec![Message::user().with_text("Say 'Stream 2'")]; + + let stream1 = provider + .stream("You are a helpful assistant.", &messages1, &[]) + .await?; + + let stream2 = provider + .stream("You are a helpful assistant.", &messages2, &[]) + .await?; + + // Process both streams concurrently + let (result1, result2) = tokio::join!( + process_stream(stream1, "Stream 1"), + process_stream(stream2, "Stream 2") + ); + + let content1 = result1?; + let content2 = result2?; + + println!("Stream 1 content: {}", content1); + println!("Stream 2 content: {}", content2); + + assert!(content1.contains("Stream 1") || content1.contains("1"), "First stream should mention Stream 1"); + assert!(content2.contains("Stream 2") || content2.contains("2"), "Second stream should mention Stream 2"); + + Ok(()) + } + + // Helper function to process a stream and collect content + async fn process_stream( + mut stream: goose::providers::base::MessageStream, + label: &str + ) -> Result { + let mut content = String::new(); + let mut chunk_count = 0; + + while let Some(result) = stream.next().await { + let (message, _usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + let text = msg.as_concat_text(); + if !text.is_empty() { + content.push_str(&text); + } + } + } + + println!("{}: Received {} chunks", label, chunk_count); + Ok(content) + } +} \ No newline at end of file From 82f909bd83bbaa1fc69383101627259876a00006 Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Mon, 8 Sep 2025 15:34:39 +1000 Subject: [PATCH 2/2] return streaming flag --- crates/goose/src/providers/tetrate.rs | 15 +- crates/goose/tests/tetrate_streaming.rs | 191 ++++++++++++++---------- 2 files changed, 123 insertions(+), 83 deletions(-) diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index ac5582c6b048..51951bae8160 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -46,6 +46,7 @@ pub struct TetrateProvider { #[serde(skip)] api_client: ApiClient, model: ModelConfig, + supports_streaming: bool, } impl_provider_default!(TetrateProvider); @@ -64,7 +65,11 @@ impl TetrateProvider { .with_header("HTTP-Referer", "https://block.github.io/goose")? .with_header("X-Title", "Goose")?; - Ok(Self { api_client, model }) + Ok(Self { + api_client, + model, + supports_streaming: true, + }) } async fn post(&self, payload: &Value) -> Result { @@ -199,7 +204,7 @@ impl Provider for TetrateProvider { tools, &super::utils::ImageFormat::OpenAi, )?; - + // Enable streaming payload["stream"] = json!(true); payload["stream_options"] = json!({ @@ -210,7 +215,7 @@ impl Provider for TetrateProvider { .api_client .response_post("v1/chat/completions", &payload) .await?; - + let response = handle_status_openai_compat(response).await?; let stream = response.bytes_stream().map_err(io::Error::other); let model_config = self.model.clone(); @@ -302,4 +307,8 @@ impl Provider for TetrateProvider { models.sort(); Ok(Some(models)) } + + fn supports_streaming(&self) -> bool { + self.supports_streaming + } } diff --git a/crates/goose/tests/tetrate_streaming.rs b/crates/goose/tests/tetrate_streaming.rs index 7f94411bd8e6..27d5ea1f5183 100644 --- a/crates/goose/tests/tetrate_streaming.rs +++ b/crates/goose/tests/tetrate_streaming.rs @@ -1,9 +1,9 @@ use anyhow::Result; use futures::StreamExt; use goose::conversation::message::{Message, MessageContent}; +use goose::model::ModelConfig; use goose::providers::base::Provider; use goose::providers::tetrate::TetrateProvider; -use goose::model::ModelConfig; use rmcp::model::Tool; use rmcp::object; use serial_test::serial; @@ -24,45 +24,53 @@ mod tetrate_streaming_tests { #[ignore] // Ignore by default, run with --ignored flag when API key is available async fn test_tetrate_streaming_basic() -> Result<()> { let provider = create_test_provider()?; - - let messages = vec![ - Message::user().with_text("Count from 1 to 5, one number at a time.") - ]; - + + let messages = vec![Message::user().with_text("Count from 1 to 5, one number at a time.")]; + let mut stream = provider - .stream("You are a helpful assistant that counts numbers.", &messages, &[]) + .stream( + "You are a helpful assistant that counts numbers.", + &messages, + &[], + ) .await?; - + let mut chunk_count = 0; let mut content_chunks = Vec::new(); - + while let Some(result) = stream.next().await { let (message, usage) = result?; chunk_count += 1; - + if let Some(msg) = message { let text = msg.as_concat_text(); if !text.is_empty() { content_chunks.push(text); } } - + // Check if we have usage information in the final chunk if usage.is_some() { println!("Received usage information in chunk {}", chunk_count); } } - + assert!(chunk_count > 0, "Should receive at least one chunk"); assert!(!content_chunks.is_empty(), "Should receive some content"); - + let full_content = content_chunks.join(""); println!("Full streamed content: {}", full_content); - + // Verify the response contains numbers - assert!(full_content.contains('1'), "Response should contain number 1"); - assert!(full_content.contains('5'), "Response should contain number 5"); - + assert!( + full_content.contains('1'), + "Response should contain number 1" + ); + assert!( + full_content.contains('5'), + "Response should contain number 5" + ); + Ok(()) } @@ -71,7 +79,7 @@ mod tetrate_streaming_tests { #[ignore] async fn test_tetrate_streaming_with_tools() -> Result<()> { let provider = create_test_provider()?; - + // Define a simple tool let weather_tool = Tool::new( "get_weather", @@ -87,22 +95,24 @@ mod tetrate_streaming_tests { "required": ["location"] }), ); - - let messages = vec![ - Message::user().with_text("What's the weather in San Francisco?") - ]; - + + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + let mut stream = provider - .stream("You are a helpful assistant with access to weather information.", &messages, &[weather_tool]) + .stream( + "You are a helpful assistant with access to weather information.", + &messages, + &[weather_tool], + ) .await?; - + let mut received_tool_call = false; let mut chunk_count = 0; - + while let Some(result) = stream.next().await { let (message, _usage) = result?; chunk_count += 1; - + if let Some(msg) = message { // Check if message contains tool requests for content in &msg.content { @@ -113,7 +123,7 @@ mod tetrate_streaming_tests { } } } - + assert!(chunk_count > 0, "Should receive at least one chunk"); // Note: Tool calls might not be supported in streaming for all models // This is more of a capability test than a requirement @@ -122,7 +132,7 @@ mod tetrate_streaming_tests { } else { println!("⚠ Streaming with tools may not be fully supported"); } - + Ok(()) } @@ -131,26 +141,27 @@ mod tetrate_streaming_tests { #[ignore] async fn test_tetrate_streaming_empty_response() -> Result<()> { let provider = create_test_provider()?; - + // This might result in a very short or empty response - let messages = vec![ - Message::user().with_text("") - ]; - + let messages = vec![Message::user().with_text("")]; + let mut stream = provider .stream("You are a helpful assistant.", &messages, &[]) .await?; - + let mut chunk_count = 0; - + while let Some(result) = stream.next().await { let (_message, _usage) = result?; chunk_count += 1; } - + // Even with empty input, we should get at least one chunk (possibly with finish_reason) - assert!(chunk_count > 0, "Should receive at least one chunk even with empty input"); - + assert!( + chunk_count > 0, + "Should receive at least one chunk even with empty input" + ); + Ok(()) } @@ -159,40 +170,56 @@ mod tetrate_streaming_tests { #[ignore] async fn test_tetrate_streaming_long_response() -> Result<()> { let provider = create_test_provider()?; - - let messages = vec![ - Message::user().with_text("Write a detailed 3-paragraph essay about the importance of streaming in modern APIs.") - ]; - + + let messages = vec![Message::user().with_text( + "Write a detailed 3-paragraph essay about the importance of streaming in modern APIs.", + )]; + let mut stream = provider - .stream("You are a helpful assistant that writes detailed essays.", &messages, &[]) + .stream( + "You are a helpful assistant that writes detailed essays.", + &messages, + &[], + ) .await?; - + let mut chunk_count = 0; let mut total_content_length = 0; - + while let Some(result) = stream.next().await { let (message, usage) = result?; chunk_count += 1; - + if let Some(msg) = message { let text = msg.as_concat_text(); total_content_length += text.len(); } - + // Final chunk should have usage information if let Some(usage_info) = usage { println!("Final usage: {:?}", usage_info.usage); - assert!(usage_info.usage.output_tokens.unwrap_or(0) > 0, "Should have output tokens"); + assert!( + usage_info.usage.output_tokens.unwrap_or(0) > 0, + "Should have output tokens" + ); } } - - println!("Received {} chunks with total content length: {}", chunk_count, total_content_length); - + + println!( + "Received {} chunks with total content length: {}", + chunk_count, total_content_length + ); + // For a detailed essay, we expect multiple chunks and substantial content - assert!(chunk_count > 5, "Long response should be streamed in multiple chunks"); - assert!(total_content_length > 100, "Essay should have substantial content"); - + assert!( + chunk_count > 5, + "Long response should be streamed in multiple chunks" + ); + assert!( + total_content_length > 100, + "Essay should have substantial content" + ); + Ok(()) } @@ -201,24 +228,22 @@ mod tetrate_streaming_tests { async fn test_tetrate_streaming_error_handling() -> Result<()> { // Test with invalid API key to ensure error handling works std::env::set_var("TETRATE_API_KEY", "invalid-key-for-testing"); - + let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; let provider = TetrateProvider::from_env(model_config)?; - - let messages = vec![ - Message::user().with_text("Hello") - ]; - + + let messages = vec![Message::user().with_text("Hello")]; + let result = provider .stream("You are a helpful assistant.", &messages, &[]) .await; - + // We expect this to fail with an authentication error assert!(result.is_err(), "Should fail with invalid API key"); - + // Clean up std::env::remove_var("TETRATE_API_KEY"); - + Ok(()) } @@ -227,49 +252,55 @@ mod tetrate_streaming_tests { #[ignore] async fn test_tetrate_streaming_concurrent_streams() -> Result<()> { let provider = create_test_provider()?; - + // Create multiple concurrent streams let messages1 = vec![Message::user().with_text("Say 'Stream 1'")]; let messages2 = vec![Message::user().with_text("Say 'Stream 2'")]; - + let stream1 = provider .stream("You are a helpful assistant.", &messages1, &[]) .await?; - + let stream2 = provider .stream("You are a helpful assistant.", &messages2, &[]) .await?; - + // Process both streams concurrently let (result1, result2) = tokio::join!( process_stream(stream1, "Stream 1"), process_stream(stream2, "Stream 2") ); - + let content1 = result1?; let content2 = result2?; - + println!("Stream 1 content: {}", content1); println!("Stream 2 content: {}", content2); - - assert!(content1.contains("Stream 1") || content1.contains("1"), "First stream should mention Stream 1"); - assert!(content2.contains("Stream 2") || content2.contains("2"), "Second stream should mention Stream 2"); - + + assert!( + content1.contains("Stream 1") || content1.contains("1"), + "First stream should mention Stream 1" + ); + assert!( + content2.contains("Stream 2") || content2.contains("2"), + "Second stream should mention Stream 2" + ); + Ok(()) } // Helper function to process a stream and collect content async fn process_stream( mut stream: goose::providers::base::MessageStream, - label: &str + label: &str, ) -> Result { let mut content = String::new(); let mut chunk_count = 0; - + while let Some(result) = stream.next().await { let (message, _usage) = result?; chunk_count += 1; - + if let Some(msg) = message { let text = msg.as_concat_text(); if !text.is_empty() { @@ -277,8 +308,8 @@ mod tetrate_streaming_tests { } } } - + println!("{}: Received {} chunks", label, chunk_count); Ok(content) } -} \ No newline at end of file +}