diff --git a/clust_macros/Cargo.toml b/clust_macros/Cargo.toml index 0e87d2e..f71a250 100644 --- a/clust_macros/Cargo.toml +++ b/clust_macros/Cargo.toml @@ -18,3 +18,6 @@ quote = "1.0.*" syn = "2.0.*" proc-macro2 = "1.0.*" clust = { path = ".." } + +[dev-dependencies] +tokio = { version = "1.37.0", features = ["macros"] } diff --git a/clust_macros/src/check_result.rs b/clust_macros/src/check_result.rs new file mode 100644 index 0000000..c53c25e --- /dev/null +++ b/clust_macros/src/check_result.rs @@ -0,0 +1,69 @@ +use syn::Type; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum ReturnType { + Value, + Result, +} + +pub(crate) fn get_return_type(ty: &Type) -> ReturnType { + if is_result_type(ty) { + ReturnType::Result + } else { + ReturnType::Value + } +} + +fn is_result_type(ty: &Type) -> bool { + match ty { + | Type::Path(type_path) => { + let path_segments = &type_path.path.segments; + path_segments.last().map_or(false, |last_segment| { + if last_segment.ident == "Result" { + match &last_segment.arguments { + syn::PathArguments::AngleBracketed(args) => args.args.len() == 2, + _ => false, + } + } else if path_segments.len() >= 2 { + path_segments + .iter() + .rev() + .nth(1) + .map_or(false, |second_last_segment| { + second_last_segment.ident == "result" + && match &last_segment.arguments { + syn::PathArguments::AngleBracketed(args) => args.args.len() == 2, + _ => false, + } + }) + } else { + false + } + }) + }, + | _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + type TestResult = Result; + + #[test] + fn test_get_return_type() { + let ty = syn::parse_str::("Result").unwrap(); + assert_eq!(get_return_type(&ty), ReturnType::Result); + + let ty = + syn::parse_str::("std::result::Result").unwrap(); + assert_eq!(get_return_type(&ty), ReturnType::Result); + + let ty = syn::parse_str::("i32").unwrap(); + assert_eq!(get_return_type(&ty), ReturnType::Value); + + let ty: Type = syn::parse_str::("TestResult").unwrap(); + assert_eq!(get_return_type(&ty), ReturnType::Value); + } +} diff --git a/clust_macros/src/lib.rs b/clust_macros/src/lib.rs index 550b3fd..e6f47d2 100644 --- a/clust_macros/src/lib.rs +++ b/clust_macros/src/lib.rs @@ -1,22 +1,14 @@ -use crate::tool::{impl_tool, impl_tool_with_result}; +use crate::tool::impl_tool; use proc_macro::TokenStream; +mod check_result; mod tool; #[proc_macro_attribute] pub fn clust_tool( - attr: TokenStream, + _attr: TokenStream, item: TokenStream, ) -> TokenStream { let item_func = syn::parse::(item).unwrap(); impl_tool(&item_func) } - -#[proc_macro_attribute] -pub fn clust_tool_result( - attr: TokenStream, - item: TokenStream, -) -> TokenStream { - let item_func = syn::parse::(item).unwrap(); - impl_tool_with_result(&item_func) -} diff --git a/clust_macros/src/tool.rs b/clust_macros/src/tool.rs index a925e48..a05697b 100644 --- a/clust_macros/src/tool.rs +++ b/clust_macros/src/tool.rs @@ -4,6 +4,7 @@ use proc_macro::TokenStream; use proc_macro2::{Ident, Span}; use std::collections::BTreeMap; +use crate::check_result::{get_return_type, ReturnType}; use quote::{quote, ToTokens}; use syn::{AttrStyle, Expr, ItemFn, Meta}; @@ -364,12 +365,80 @@ fn quote_call_with_result( } } +fn quote_call_async( + func: &ItemFn, + info: &ToolInformation, +) -> proc_macro2::TokenStream { + let name = info.name.clone(); + let ident = func.sig.ident.clone(); + let parameters = quote_invoke_parameters(info); + let quote_result = quote_result(name.clone()); + + quote! { + async fn call(&self, function_calls: clust::messages::FunctionCalls) + -> std::result::Result { + if function_calls.invoke.tool_name != #name { + return Err(clust::messages::ToolCallError::ToolNameMismatch); + } + + let result = #ident( + #( + #parameters + ),* + ).await; + + #quote_result + } + } +} + +fn quote_call_async_with_result( + func: &ItemFn, + info: &ToolInformation, +) -> proc_macro2::TokenStream { + let name = info.name.clone(); + let ident = func.sig.ident.clone(); + let parameters = quote_invoke_parameters(info); + let quote_result = quote_result_with_match(name.clone()); + + quote! { + async fn call(&self, function_calls: clust::messages::FunctionCalls) + -> std::result::Result { + if function_calls.invoke.tool_name != #name { + return Err(clust::messages::ToolCallError::ToolNameMismatch); + } + + let result = #ident( + #( + #parameters + ),* + ).await; + + #quote_result + } + } +} + fn impl_tool_for_function( func: &ItemFn, info: ToolInformation, ) -> proc_macro2::TokenStream { - let description_quote = quote_description(&info); - let call_quote = quote_call(func, &info); + let impl_description = quote_description(&info); + + let impl_call = match func.sig.output.clone() { + | syn::ReturnType::Default => { + panic!("Function must have a displayable return type") + }, + | syn::ReturnType::Type(_, _type) => { + let return_type = get_return_type(&_type); + + match return_type { + | ReturnType::Value => quote_call(func, &info), + | ReturnType::Result => quote_call_with_result(func, &info), + } + }, + }; + let struct_name = Ident::new( &format!("ClustTool_{}", info.name), Span::call_site(), @@ -384,18 +453,34 @@ fn impl_tool_for_function( // Implement Tool trait for generated tool struct impl clust::messages::Tool for #struct_name { - #description_quote - #call_quote + #impl_description + #impl_call } } } -fn impl_tool_for_function_with_result( +fn impl_tool_for_async_function( func: &ItemFn, info: ToolInformation, ) -> proc_macro2::TokenStream { - let description_quote = quote_description(&info); - let call_quote = quote_call_with_result(func, &info); + let impl_description = quote_description(&info); + + let impl_call = match func.sig.output.clone() { + | syn::ReturnType::Default => { + panic!("Function must have a displayable return type") + }, + | syn::ReturnType::Type(_, _type) => { + let return_type = get_return_type(&_type); + + match return_type { + | ReturnType::Value => quote_call_async(func, &info), + | ReturnType::Result => { + quote_call_async_with_result(func, &info) + }, + } + }, + }; + let struct_name = Ident::new( &format!("ClustTool_{}", info.name), Span::call_site(), @@ -409,21 +494,21 @@ fn impl_tool_for_function_with_result( pub struct #struct_name; // Implement Tool trait for generated tool struct - impl clust::messages::Tool for #struct_name { - #description_quote - #call_quote + impl clust::messages::AsyncTool for #struct_name { + #impl_description + #impl_call } } } pub(crate) fn impl_tool(func: &ItemFn) -> TokenStream { let tool_information = get_tool_information(func); - impl_tool_for_function(func, tool_information).into() -} -pub(crate) fn impl_tool_with_result(func: &ItemFn) -> TokenStream { - let tool_information = get_tool_information(func); - impl_tool_for_function_with_result(func, tool_information).into() + if func.sig.asyncness.is_some() { + impl_tool_for_async_function(func, tool_information).into() + } else { + impl_tool_for_function(func, tool_information).into() + } } #[cfg(test)] diff --git a/clust_macros/tests/tool_async.rs b/clust_macros/tests/tool_async.rs new file mode 100644 index 0000000..962b77d --- /dev/null +++ b/clust_macros/tests/tool_async.rs @@ -0,0 +1,62 @@ +use std::collections::BTreeMap; + +use clust::messages::{AsyncTool, FunctionCalls, FunctionResults, Invoke}; + +use clust_macros::clust_tool; + +/// An asynchronous function for testing. +/// +/// ## Arguments +/// - `arg1` - First argument. +#[clust_tool] +async fn test_function(arg1: i32) -> i32 { + arg1 + 1 +} + +#[test] +fn test_description() { + let tool = ClustTool_test_function {}; + + assert_eq!( + tool.description().to_string(), + r#" + + test_function + An asynchronous function for testing. + + + arg1 + i32 + First argument. + + +"# + ); +} + +#[tokio::test] +async fn test_call() { + let tool = ClustTool_test_function {}; + + let function_calls = FunctionCalls { + invoke: Invoke { + tool_name: String::from("test_function"), + parameters: BTreeMap::from_iter(vec![( + "arg1".to_string(), + "42".to_string(), + )]), + }, + }; + + let result = tool + .call(function_calls) + .await + .unwrap(); + + if let FunctionResults::Result(result) = result { + assert_eq!(result.tool_name, "test_function"); + assert_eq!(result.stdout, "43"); + } else { + panic!("Expected FunctionResults::Result"); + } +} diff --git a/clust_macros/tests/tool_async_with_result.rs b/clust_macros/tests/tool_async_with_result.rs new file mode 100644 index 0000000..9c7dceb --- /dev/null +++ b/clust_macros/tests/tool_async_with_result.rs @@ -0,0 +1,110 @@ +use std::collections::BTreeMap; +use std::fmt::Display; + +use clust::messages::{AsyncTool, FunctionCalls, FunctionResults, Invoke}; + +use clust_macros::clust_tool; + +/// An asynchronous function with returning result for testing. +/// +/// ## Arguments +/// - `arg1` - First argument. +/// +/// ## Examples +/// ```rust +/// ``` +#[clust_tool] +async fn test_function_with_result(arg1: i32) -> Result { + if arg1 >= 0 { + Ok(arg1 as u32) + } else { + Err(TestError { + message: "arg1 is negative".to_string(), + }) + } +} + +struct TestError { + message: String, +} + +impl Display for TestError { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +#[test] +fn test_description() { + let tool = ClustTool_test_function_with_result {}; + + assert_eq!( + tool.description().to_string(), + r#" + + test_function_with_result + An asynchronous function with returning result for testing. + + + arg1 + i32 + First argument. + + +"# + ); +} + +#[tokio::test] +async fn test_call() { + let tool = ClustTool_test_function_with_result {}; + + let function_calls = FunctionCalls { + invoke: Invoke { + tool_name: String::from("test_function_with_result"), + parameters: BTreeMap::from_iter(vec![( + "arg1".to_string(), + "1".to_string(), + )]), + }, + }; + + let result = tool + .call(function_calls) + .await + .unwrap(); + + if let FunctionResults::Result(result) = result { + assert_eq!( + result.tool_name, + "test_function_with_result" + ); + assert_eq!(result.stdout, "1"); + } else { + panic!("Expected FunctionResults::Result"); + } + + let function_calls = FunctionCalls { + invoke: Invoke { + tool_name: String::from("test_function_with_result"), + parameters: BTreeMap::from_iter(vec![( + "arg1".to_string(), + "-1".to_string(), + )]), + }, + }; + + let result = tool + .call(function_calls) + .await + .unwrap(); + + if let FunctionResults::Error(error) = result { + assert_eq!(error, "arg1 is negative"); + } else { + panic!("Expected FunctionResults::Error"); + } +} diff --git a/clust_macros/tests/tool_with_result.rs b/clust_macros/tests/tool_with_result.rs index 6e22dbe..b3029bd 100644 --- a/clust_macros/tests/tool_with_result.rs +++ b/clust_macros/tests/tool_with_result.rs @@ -5,7 +5,7 @@ use clust::messages::{ FunctionCalls, FunctionResults, Invoke, Tool, }; -use clust_macros::clust_tool_result; +use clust_macros::clust_tool; /// A function with returning result for testing. /// @@ -15,7 +15,7 @@ use clust_macros::clust_tool_result; /// ## Examples /// ```rust /// ``` -#[clust_tool_result] +#[clust_tool] fn test_function_with_result(arg1: i32) -> Result { if arg1 >= 0 { Ok(arg1 as u32) diff --git a/src/messages.rs b/src/messages.rs index 8a28b42..b5ddd3f 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -38,6 +38,7 @@ pub use error::MessageChunkTypeError; pub use error::MessagesError; pub use error::StreamError; pub use error::ToolCallError; +pub use function::AsyncTool; pub use function::FunctionCalls; pub use function::FunctionResult; pub use function::FunctionResults; diff --git a/src/messages/function.rs b/src/messages/function.rs index 0bc0f29..1458192 100644 --- a/src/messages/function.rs +++ b/src/messages/function.rs @@ -38,6 +38,19 @@ pub trait Tool { ) -> Result; } +/// A tool is an asynchronous function that can be called by the assistant. +pub trait AsyncTool { + /// Returns the description of the tool. + fn description(&self) -> ToolDescription; + + /// Calls the tool with the provided function calls. + fn call( + &self, + function_calls: FunctionCalls, + ) -> impl std::future::Future> + + Send; +} + /// ## XML example /// ```xml ///